// sql_forge_macro/lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17    Stmt, Token, Type,
18};
19
20// =============================================================================
21// Input types
22// =============================================================================
23
24mod kw {
25    syn::custom_keyword!(scalar);
26}
27
28/// A `:name = expr` parameter binding.
29#[derive(Clone)]
30struct ParamAssign {
31    name: Ident,
32    expr: Expr,
33}
34
35impl Parse for ParamAssign {
36    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37        input.parse::<Token![:]>()?;
38        let name: Ident = input.parse()?;
39        input.parse::<Token![=]>()?;
40        let expr: Expr = input.parse()?;
41        Ok(Self { name, expr })
42    }
43}
44
45/// A single section value: SQL string + optional local parameter bindings.
46#[derive(Clone)]
47struct SectionFragment {
48    sql: String,
49    span: Span,
50    params: ParamsSource,
51}
52
53/// One arm in a `match { ... }` inside a section assignment.
54#[derive(Clone)]
55struct SectionMatchArm {
56    pat: Pat,
57    guard: Option<Expr>,
58    value: SectionValue,
59}
60
61/// The right-hand side of a section `#name = ...` assignment.
62#[derive(Clone)]
63enum SectionValue {
64    /// A plain string (or `(string, params)` tuple).
65    Single(SectionFragment),
66    /// A tuple of values, one per section when using `#(a, b) = ...`.
67    Grouped(Vec<SectionValue>),
68    /// A `match expr { arm => ..., arm => ... }` expression.
69    Match {
70        expr: Expr,
71        arms: Vec<SectionMatchArm>,
72    },
73}
74
75/// A full `#name = value` or `#(a, b) = value` section assignment.
76#[derive(Clone)]
77struct SectionAssign {
78    names: Vec<Ident>,
79    value: SectionValue,
80}
81
82impl Parse for SectionAssign {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        input.parse::<Token![#]>()?;
85
86        // Parse `#(a, b)` for grouped sections, or `#name` for single.
87        let names = if input.peek(syn::token::Paren) {
88            let content;
89            syn::parenthesized!(content in input);
90            let mut out = Vec::new();
91            while !content.is_empty() {
92                out.push(content.parse::<Ident>()?);
93                if content.is_empty() {
94                    break;
95                }
96                content.parse::<Token![,]>()?;
97            }
98            if out.is_empty() {
99                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100            }
101            out
102        } else {
103            vec![input.parse::<Ident>()?]
104        };
105
106        input.parse::<Token![=]>()?;
107        let value = parse_section_value(input, names.len())?;
108        Ok(Self { names, value })
109    }
110}
111
112/// The fully-parsed macro invocation.
113struct SqlForgeInput {
114    db: Option<Type>,
115    result: ResultSpec,
116    force_scalar: bool,
117    sql: SqlTemplate,
118    params: ParamsSource,
119    sections: Vec<SectionAssign>,
120    batch: Option<Expr>,
121}
122
123/// One entry in a result map `(>key = Model, ...)`.
124#[derive(Clone)]
125struct ResultAssign {
126    name: Ident,
127    model: Type,
128    force_scalar: bool,
129}
130
131#[derive(Clone)]
132enum ResultSpec {
133    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
134    None,
135    /// e.g. `sql_forge!(User, ...)`
136    Single(Box<Type>),
137    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
138    Group(Vec<ResultAssign>),
139}
140
141#[derive(Clone)]
142enum ParamsSource {
143    None,
144    /// `( :name = expr, ... )`
145    Map(Vec<ParamAssign>),
146    /// A struct expression whose fields are matched to `:name` placeholders.
147    Struct(Box<Expr>),
148}
149
150/// The SQL template. Only string literals are supported.
151enum SqlTemplate {
152    Literal(LitStr),
153}
154
155impl SqlTemplate {
156    fn span(&self) -> Span {
157        match self {
158            Self::Literal(lit) => lit.span(),
159        }
160    }
161
162    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
163    fn into_segments(self) -> Result<Vec<Segment>, String> {
164        match self {
165            Self::Literal(lit) => parse_literal_segments(&lit.value()),
166        }
167    }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171    if input.peek(LitStr) {
172        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173    } else {
174        Err(input.error("sql_forge!: SQL template must be a string literal"))
175    }
176}
177
178/// One piece of the parsed SQL template.
179#[derive(Clone)]
180enum Segment {
181    /// Plain SQL text (may contain `:param` placeholders).
182    Text(String),
183    /// A `{#section_name}` placeholder.
184    Section { name: String },
185    /// A `{( ... )}` batch value template (repeated per batch item).
186    Batch { parts: Vec<TextPart> },
187}
188
189/// A fragment of SQL text after splitting on `:param`.
190#[derive(Clone)]
191enum TextPart {
192    /// Literal SQL text.
193    Lit(String),
194    /// A `:param` placeholder.
195    Param { name: String, is_list: bool },
196}
197
198// =============================================================================
199// Parsing helpers
200// =============================================================================
201
202/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
203enum MapKind {
204    Results,
205    Params,
206    Sections,
207}
208
209// =============================================================================
210// Section value parsing (string fragments, match, grouped tuples)
211// =============================================================================
212
213fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
214    let fork = input.fork();
215    let content;
216    syn::parenthesized!(content in fork);
217
218    if content.is_empty() {
219        return Err(input.error("sql_forge!: map argument cannot be empty"));
220    }
221
222    if content.peek(Token![>]) {
223        Ok(Some(MapKind::Results))
224    } else if content.peek(Token![:]) {
225        Ok(Some(MapKind::Params))
226    } else if content.peek(Token![#]) {
227        Ok(Some(MapKind::Sections))
228    } else {
229        Ok(None)
230    }
231}
232
233impl Parse for ResultAssign {
234    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
235        input.parse::<Token![>]>()?;
236        let name: Ident = input.parse()?;
237        input.parse::<Token![=]>()?;
238        let (force_scalar, model) = if input.peek(kw::scalar) {
239            input.parse::<kw::scalar>()?;
240            (true, input.parse::<Type>()?)
241        } else {
242            (false, input.parse::<Type>()?)
243        };
244        Ok(Self {
245            name,
246            model,
247            force_scalar,
248        })
249    }
250}
251
252fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
253    let content;
254    syn::parenthesized!(content in input);
255
256    let mut results = Vec::new();
257    while !content.is_empty() {
258        results.push(content.parse::<ResultAssign>()?);
259        if content.is_empty() {
260            break;
261        }
262        content.parse::<Token![,]>()?;
263    }
264
265    if results.is_empty() {
266        return Err(input.error("sql_forge!: result map cannot be empty"));
267    }
268
269    Ok(results)
270}
271
272fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
273    let content;
274    syn::parenthesized!(content in input);
275
276    let mut params = Vec::new();
277    while !content.is_empty() {
278        params.push(content.parse::<ParamAssign>()?);
279        if content.is_empty() {
280            break;
281        }
282        content.parse::<Token![,]>()?;
283    }
284
285    Ok(params)
286}
287
288fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
289    let content;
290    syn::parenthesized!(content in input);
291
292    let mut sections = Vec::new();
293    while !content.is_empty() {
294        sections.push(content.parse::<SectionAssign>()?);
295        if content.is_empty() {
296            break;
297        }
298        content.parse::<Token![,]>()?;
299    }
300
301    Ok(sections)
302}
303
304fn parse_params_source_expr(
305    input: ParseStream<'_>,
306    allow_sections: bool,
307) -> syn::Result<ParamsSource> {
308    if input.peek(syn::token::Paren) {
309        match detect_parenthesized_map_kind(input)? {
310            Some(MapKind::Results) => Err(input
311                .error("sql_forge!: result maps are only allowed as the macro result argument")),
312            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
313            Some(MapKind::Sections) if allow_sections => {
314                Err(input.error("sql_forge!: section maps are not allowed here"))
315            }
316            Some(MapKind::Sections) => Err(input.error(
317                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
318            )),
319            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
320        }
321    } else {
322        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
323    }
324}
325
326fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
327    if input.peek(syn::token::Paren) {
328        let fork = input.fork();
329        let content;
330        syn::parenthesized!(content in fork);
331
332        if let Ok(first_expr) = content.parse::<Expr>() {
333            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
334                let _ = parse_params_source_expr(&content, false)?;
335                if content.peek(Token![,]) {
336                    content.parse::<Token![,]>()?;
337                }
338                if content.is_empty() {
339                    let content;
340                    syn::parenthesized!(content in input);
341                    let first_expr: Expr = content.parse()?;
342                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
343                        input.error("sql_forge!: section tuple must start with a string literal")
344                    })?;
345                    let span = first_expr.span();
346                    content.parse::<Token![,]>()?;
347                    let params = parse_params_source_expr(&content, false)?;
348                    if content.peek(Token![,]) {
349                        content.parse::<Token![,]>()?;
350                    }
351                    if !content.is_empty() {
352                        return Err(content.error(
353                            "sql_forge!: unexpected tokens after section-local parameter source",
354                        ));
355                    }
356                    return Ok(SectionFragment { sql, span, params });
357                }
358            }
359        }
360    }
361
362    let expr: Expr = input.parse()?;
363    let sql = extract_lit_str(&expr).ok_or_else(|| {
364        input
365            .error("sql_forge!: section values must be string literals or (string literal, params)")
366    })?;
367    Ok(SectionFragment {
368        sql,
369        span: expr.span(),
370        params: ParamsSource::None,
371    })
372}
373
374fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
375    if input.peek(Token![match]) {
376        input.parse::<Token![match]>()?;
377        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
378        let content;
379        syn::braced!(content in input);
380        let mut arms = Vec::new();
381        while !content.is_empty() {
382            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
383            let guard = if content.peek(Token![if]) {
384                content.parse::<Token![if]>()?;
385                Some(content.parse::<Expr>()?)
386            } else {
387                None
388            };
389            content.parse::<Token![=>]>()?;
390            let value = parse_section_value(&content, width)?;
391            if content.peek(Token![,]) {
392                content.parse::<Token![,]>()?;
393            }
394            arms.push(SectionMatchArm { pat, guard, value });
395        }
396        return Ok(SectionValue::Match { expr, arms });
397    }
398
399    if width == 1 {
400        return Ok(SectionValue::Single(parse_section_fragment(input)?));
401    }
402
403    let content;
404    syn::parenthesized!(content in input);
405    let mut items = Vec::new();
406    while !content.is_empty() {
407        items.push(parse_section_value(&content, 1)?);
408        if content.is_empty() {
409            break;
410        }
411        content.parse::<Token![,]>()?;
412    }
413
414    if items.len() != width {
415        return Err(input.error(format!(
416            "sql_forge!: grouped section value must provide exactly {} items",
417            width,
418        )));
419    }
420
421    Ok(SectionValue::Grouped(items))
422}
423
424// =============================================================================
425// Top-level macro input parsing (SqlForgeInput::parse)
426// =============================================================================
427
428impl Parse for SqlForgeInput {
429    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
430        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
431            let sql = parse_sql_template(input)?;
432            (None, ResultSpec::None, false, sql)
433        } else if input.peek(kw::scalar) {
434            input.parse::<kw::scalar>()?;
435            let model: Type = input.parse()?;
436            input.parse::<Token![,]>()?;
437            let sql = parse_sql_template(input)?;
438            (None, ResultSpec::Single(Box::new(model)), true, sql)
439        } else if input.peek(syn::token::Paren) {
440            let result_map_kind = detect_parenthesized_map_kind(input)?;
441            match result_map_kind {
442                Some(MapKind::Results) => {
443                    let result = ResultSpec::Group(parse_result_map(input)?);
444                    input.parse::<Token![,]>()?;
445                    let sql = parse_sql_template(input)?;
446                    (None, result, false, sql)
447                }
448                _ => {
449                    return Err(input.error(
450                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
451                    ));
452                }
453            }
454        } else {
455            let first_ty: Type = input.parse()?;
456            input.parse::<Token![,]>()?;
457
458            if input.peek(LitStr) {
459                let model = first_ty;
460                let sql = parse_sql_template(input)?;
461                (None, ResultSpec::Single(Box::new(model)), false, sql)
462            } else if input.peek(kw::scalar) {
463                input.parse::<kw::scalar>()?;
464                let model: Type = input.parse()?;
465                input.parse::<Token![,]>()?;
466                let sql = parse_sql_template(input)?;
467                (
468                    Some(first_ty),
469                    ResultSpec::Single(Box::new(model)),
470                    true,
471                    sql,
472                )
473            } else if input.peek(syn::token::Paren)
474                && matches!(
475                    detect_parenthesized_map_kind(input)?,
476                    Some(MapKind::Results)
477                )
478            {
479                let result = ResultSpec::Group(parse_result_map(input)?);
480                input.parse::<Token![,]>()?;
481                let sql = parse_sql_template(input)?;
482                (Some(first_ty), result, false, sql)
483            } else {
484                let db = Some(first_ty);
485                let model: Type = input.parse()?;
486                input.parse::<Token![,]>()?;
487                let sql = parse_sql_template(input)?;
488                (db, ResultSpec::Single(Box::new(model)), false, sql)
489            }
490        };
491
492        let mut batch = None;
493        let mut params = ParamsSource::None;
494        let mut sections = Vec::new();
495        let mut seen_params = false;
496        let mut seen_sections = false;
497
498        if input.parse::<Token![,]>().is_ok() {
499            while !input.is_empty() {
500                if input.peek(Token![..]) {
501                    if batch.is_some() {
502                        return Err(
503                            input.error("sql_forge!: only one batch source argument is allowed")
504                        );
505                    }
506                    input.parse::<Token![..]>()?;
507                    batch = Some(input.parse::<Expr>()?);
508                } else if input.peek(syn::token::Paren) {
509                    match detect_parenthesized_map_kind(input)? {
510                        Some(MapKind::Results) => {
511                            return Err(input.error(
512                                "sql_forge!: result maps are only allowed as the macro result argument",
513                            ));
514                        }
515                        Some(MapKind::Params) => {
516                            if seen_params {
517                                return Err(
518                                    input.error("sql_forge!: only one parameter source is allowed")
519                                );
520                            }
521                            params = ParamsSource::Map(parse_param_map(input)?);
522                            seen_params = true;
523                        }
524                        Some(MapKind::Sections) => {
525                            if seen_sections {
526                                return Err(
527                                    input.error("sql_forge!: duplicate section map argument")
528                                );
529                            }
530                            sections = parse_section_map(input)?;
531                            seen_sections = true;
532                        }
533                        None => {
534                            if seen_params {
535                                return Err(
536                                    input.error("sql_forge!: only one parameter source is allowed")
537                                );
538                            }
539                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
540                            seen_params = true;
541                        }
542                    }
543                } else {
544                    if seen_params {
545                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
546                    }
547                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
548                    seen_params = true;
549                }
550
551                if input.parse::<Token![,]>().is_ok() {
552                    continue;
553                }
554                break;
555            }
556        }
557
558        if !input.is_empty() {
559            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
560        }
561
562        Ok(Self {
563            db,
564            result,
565            force_scalar,
566            sql,
567            params,
568            sections,
569            batch,
570        })
571    }
572}
573
574// =============================================================================
575// Database type resolution
576// =============================================================================
577
578fn resolve_db_from_env() -> Result<Type, String> {
579    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
580        return syn::parse_str::<Type>(&val).map_err(|err| {
581            format!(
582                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
583                val, err
584            )
585        });
586    }
587
588    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
589        Ok(d) => d,
590        Err(_) => {
591            return Err(
592                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
593                 or configure [package.metadata.sql_forge] in Cargo.toml"
594                    .to_string(),
595            );
596        }
597    };
598    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
599
600    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
601        format!(
602            "sql_forge!: failed to read {}: {}",
603            manifest_path.display(),
604            err
605        )
606    })?;
607
608    let value: toml::Value = toml::from_str(&cargo_toml)
609        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
610
611    let db_str = value
612        .get("package")
613        .and_then(|v| v.get("metadata"))
614        .and_then(|v| v.get("sql_forge"))
615        .and_then(|v| v.get("db"))
616        .and_then(|v| v.as_str())
617        .ok_or({
618            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
619             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
620        })?;
621
622    syn::parse_str::<Type>(db_str).map_err(|err| {
623        format!(
624            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
625            db_str, err
626        )
627    })
628}
629
630fn uses_dollar_params(db: &Type) -> bool {
631    let Type::Path(type_path) = db else {
632        return false;
633    };
634    type_path
635        .path
636        .segments
637        .last()
638        .is_some_and(|s| s.ident == "Postgres")
639}
640
641fn is_builtin_scalar_type(ty: &Type) -> bool {
642    let Type::Path(type_path) = ty else {
643        return false;
644    };
645
646    if type_path.qself.is_some()
647        || type_path.path.leading_colon.is_some()
648        || type_path.path.segments.len() != 1
649    {
650        return false;
651    }
652
653    let ident = &type_path.path.segments[0].ident;
654    ident == "i8"
655        || ident == "i16"
656        || ident == "i32"
657        || ident == "i64"
658        || ident == "isize"
659        || ident == "u8"
660        || ident == "u16"
661        || ident == "u32"
662        || ident == "u64"
663        || ident == "usize"
664        || ident == "f32"
665        || ident == "f64"
666        || ident == "bool"
667        || ident == "String"
668}
669
670fn scalar_output_type(model: &Type) -> Option<&Type> {
671    if is_builtin_scalar_type(model) {
672        return Some(model);
673    }
674    None
675}
676
677fn push_text_segment(out: &mut Vec<Segment>, text: String) {
678    if text.is_empty() {
679        return;
680    }
681    match out.last_mut() {
682        Some(Segment::Text(existing)) => existing.push_str(&text),
683        _ => out.push(Segment::Text(text)),
684    }
685}
686
687fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
688    let mut out = Vec::new();
689    let mut text = String::new();
690    let mut chars = sql.chars().peekable();
691
692    while let Some(ch) = chars.next() {
693        if ch != '{' {
694            text.push(ch);
695            continue;
696        }
697
698        if chars.peek() == Some(&'(') {
699            push_text_segment(&mut out, std::mem::take(&mut text));
700
701            let mut paren_depth = 0u32;
702            let mut content = String::new();
703            let mut found_close = false;
704            for ch in chars.by_ref() {
705                if ch == '{' {
706                    return Err(
707                        "sql_forge!: nested braces not allowed inside batch section".to_string()
708                    );
709                }
710                if ch == '}' {
711                    if paren_depth != 0 {
712                        return Err(
713                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
714                                .to_string(),
715                        );
716                    }
717                    found_close = true;
718                    break;
719                }
720                if ch == '(' {
721                    paren_depth += 1;
722                } else if ch == ')' {
723                    if paren_depth == 0 {
724                        return Err(
725                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
726                                .to_string(),
727                        );
728                    }
729                    paren_depth -= 1;
730                }
731                content.push(ch);
732            }
733            if !found_close {
734                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
735            }
736            let parts = parse_text_parts(&content);
737            for part in &parts {
738                if let TextPart::Param { is_list: true, .. } = part {
739                    return Err(
740                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
741                         batch sections; use plain parameters (:name) instead"
742                            .to_string(),
743                    );
744                }
745            }
746            out.push(Segment::Batch { parts });
747            continue;
748        }
749
750        if chars.peek() != Some(&'#') {
751            text.push(ch);
752            continue;
753        }
754
755        chars.next();
756        push_text_segment(&mut out, std::mem::take(&mut text));
757
758        let mut name = String::new();
759        loop {
760            let Some(next) = chars.next() else {
761                return Err("sql_forge!: section placeholder without closing }".to_string());
762            };
763            if next == '}' {
764                break;
765            }
766            name.push(next);
767        }
768
769        if name.is_empty() {
770            return Err("sql_forge!: empty section placeholder name".to_string());
771        }
772
773        out.push(Segment::Section { name });
774    }
775
776    push_text_segment(&mut out, text);
777    Ok(out)
778}
779
// =============================================================================
// SQL text scanning: identifier classes, backtick aliases, :param placeholders
// =============================================================================

/// Whether `ch` may begin a `:param` name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}

/// Whether `ch` may continue a `:param` name: an ASCII letter, digit or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch.is_ascii_digit() || is_ident_start(ch)
}
791
/// Cleans the contents of a backtick-quoted alias by cutting it at the first
/// `!`, `?` or `:` and trimming trailing whitespace.
///
/// The contents are returned unchanged when no such marker is present, or when
/// cutting would leave nothing (e.g. `"!x"`).
/// NOTE(review): the markers look like sqlx-style column overrides (`name!`,
/// `name?`, `name: Type`) — confirm.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(cut) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };
    match content[..cut].trim_end() {
        "" => content.to_string(),
        base => base.to_string(),
    }
}
812
813fn sanitize_runtime_sql_text(text: &str) -> String {
814    let mut out = String::with_capacity(text.len());
815    let mut chars = text.chars().peekable();
816
817    while let Some(ch) = chars.next() {
818        if ch != '`' {
819            out.push(ch);
820            continue;
821        }
822
823        let mut content = String::new();
824        let mut closed = false;
825
826        for next in chars.by_ref() {
827            if next == '`' {
828                closed = true;
829                break;
830            }
831            content.push(next);
832        }
833
834        if closed {
835            out.push('`');
836            out.push_str(&sanitize_backticked_alias_ident(&content));
837            out.push('`');
838        } else {
839            out.push('`');
840            out.push_str(&content);
841            break;
842        }
843    }
844
845    out
846}
847
848fn parse_text_parts(text: &str) -> Vec<TextPart> {
849    let mut parts = Vec::new();
850    let mut last = 0usize;
851    let mut iter = text.char_indices().peekable();
852
853    while let Some((idx, ch)) = iter.next() {
854        if ch != ':' {
855            continue;
856        }
857
858        let Some(&(next_idx, next_ch)) = iter.peek() else {
859            continue;
860        };
861
862        if !is_ident_start(next_ch) {
863            continue;
864        }
865
866        if text[..idx].ends_with(':') {
867            continue;
868        }
869
870        if last < idx {
871            parts.push(TextPart::Lit(text[last..idx].to_string()));
872        }
873
874        iter.next();
875
876        let mut name = String::new();
877        name.push(next_ch);
878        let mut end = next_idx + next_ch.len_utf8();
879
880        while let Some(&(j, c)) = iter.peek() {
881            if is_ident_continue(c) {
882                name.push(c);
883                end = j + c.len_utf8();
884                iter.next();
885            } else {
886                break;
887            }
888        }
889
890        let mut is_list = false;
891        if text[end..].starts_with("[]") {
892            is_list = true;
893            end += 2;
894        }
895
896        parts.push(TextPart::Param { name, is_list });
897        last = end;
898    }
899
900    if last < text.len() {
901        parts.push(TextPart::Lit(text[last..].to_string()));
902    }
903
904    parts
905}
906
907fn render_validator_text(
908    text: &str,
909    use_dollar_params: bool,
910    param_offset: &mut usize,
911    list_count: usize,
912) -> (String, Vec<(String, bool)>) {
913    let mut out_sql = String::new();
914    let mut occurrences = Vec::new();
915
916    for part in parse_text_parts(text) {
917        match part {
918            TextPart::Lit(lit) => out_sql.push_str(&lit),
919            TextPart::Param { name, is_list } => {
920                if is_list && list_count > 1 {
921                    let slots: Vec<String> = if use_dollar_params {
922                        (0..list_count)
923                            .map(|i| format!("${}", *param_offset + i + 1))
924                            .collect()
925                    } else {
926                        (0..list_count).map(|_| "?".to_string()).collect()
927                    };
928                    if use_dollar_params {
929                        *param_offset += list_count;
930                    }
931                    out_sql.push_str(&slots.join(", "));
932                } else if use_dollar_params {
933                    *param_offset += 1;
934                    write!(out_sql, "${}", *param_offset).unwrap();
935                } else {
936                    out_sql.push('?');
937                }
938                occurrences.push((name, is_list));
939            }
940        }
941    }
942
943    (out_sql, occurrences)
944}
945
946fn strip_expr(expr: &Expr) -> &Expr {
947    match expr {
948        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
949        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
950        Expr::Block(ExprBlock { block, .. }) => {
951            if block.stmts.len() != 1 {
952                return expr;
953            }
954            match &block.stmts[0] {
955                Stmt::Expr(inner, None) => strip_expr(inner),
956                _ => expr,
957            }
958        }
959        _ => expr,
960    }
961}
962
963fn extract_lit_str(expr: &Expr) -> Option<String> {
964    match strip_expr(expr) {
965        Expr::Lit(ExprLit {
966            lit: Lit::Str(lit), ..
967        }) => Some(lit.value()),
968        _ => None,
969    }
970}
971
972// =============================================================================
973// Preprocessing: {>key} compile-time result flags
974// =============================================================================
975
976fn result_flag_ident(name: &str) -> syn::Ident {
977    format_ident!("__enhanced_result_flag_{}", name)
978}
979
980/// Replaces `{>key}` tokens inside braced groups with `__enhanced_result_flag_key`
981/// identifiers. This is a preprocessing step so that the rest of the parser sees
982/// plain identifiers instead of braced groups it does not understand.
983fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
984    fn walk(stream: TokenStream2) -> TokenStream2 {
985        let mut out = TokenStream2::new();
986        let iter = stream.into_iter().peekable();
987
988        for token in iter {
989            match token {
990                TokenTree::Group(group) => {
991                    if group.delimiter() == Delimiter::Brace {
992                        let mut inner = group.stream().into_iter();
993                        let first = inner.next();
994                        let second = inner.next();
995                        let third = inner.next();
996
997                        if let (
998                            Some(TokenTree::Punct(p)),
999                            Some(TokenTree::Ident(name_ident)),
1000                            None,
1001                        ) = (first, second, third)
1002                        {
1003                            if p.as_char() == '>' {
1004                                let ident = result_flag_ident(&name_ident.to_string());
1005                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1006                                continue;
1007                            }
1008                        }
1009                    }
1010
1011                    let new_inner = walk(group.stream());
1012                    let mut new_group = Group::new(group.delimiter(), new_inner);
1013                    new_group.set_span(group.span());
1014                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1015                }
1016                other => out.extend(std::iter::once(other)),
1017            }
1018        }
1019
1020        out
1021    }
1022
1023    walk(input)
1024}
1025
1026fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1027    keys.iter()
1028        .map(|key| {
1029            let ident = result_flag_ident(key);
1030            let enabled = Some(key.as_str()) == active_key;
1031            quote! { let #ident: bool = #enabled; }
1032        })
1033        .collect()
1034}
1035
1036fn transpose_section_case_matrix(
1037    case_matrix: Vec<Vec<SectionFragment>>,
1038    width: usize,
1039) -> Result<Vec<Vec<SectionFragment>>, String> {
1040    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1041
1042    for row in case_matrix {
1043        if row.len() != width {
1044            return Err(
1045                "sql_forge!: grouped sections must return one item per section".to_string(),
1046            );
1047        }
1048        for (section_idx, fragment) in row.into_iter().enumerate() {
1049            per_section[section_idx].push(fragment);
1050        }
1051    }
1052
1053    Ok(per_section)
1054}
1055
1056fn collect_section_case_matrix(
1057    value: SectionValue,
1058    width: usize,
1059    active_key: Option<&str>,
1060) -> Result<Vec<Vec<SectionFragment>>, String> {
1061    match value {
1062        SectionValue::Single(fragment) => {
1063            if width != 1 {
1064                return Err(
1065                    "sql_forge!: grouped sections must return one item per section".to_string(),
1066                );
1067            }
1068            Ok(vec![vec![fragment]])
1069        }
1070        SectionValue::Grouped(values) => {
1071            if values.len() != width {
1072                return Err(
1073                    "sql_forge!: grouped sections must return one item per section".to_string(),
1074                );
1075            }
1076
1077            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1078            let mut nmax = 1usize;
1079
1080            for value in values {
1081                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1082                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1083                for mut row in item_matrix {
1084                    let fragment = row.pop().ok_or_else(|| {
1085                        "sql_forge!: grouped sections must return one item per section".to_string()
1086                    })?;
1087                    if !row.is_empty() {
1088                        return Err(
1089                            "sql_forge!: grouped sections must return one item per section"
1090                                .to_string(),
1091                        );
1092                    }
1093                    item_variants.push(fragment);
1094                }
1095                if item_variants.is_empty() {
1096                    return Err("sql_forge!: section match must have at least one arm".to_string());
1097                }
1098                nmax = nmax.max(item_variants.len());
1099                variants_by_section.push(item_variants);
1100            }
1101
1102            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1103            for case_idx in 0..nmax {
1104                let mut row = Vec::<SectionFragment>::with_capacity(width);
1105                for variants in &variants_by_section {
1106                    row.push(variants[case_idx % variants.len()].clone());
1107                }
1108                case_matrix.push(row);
1109            }
1110
1111            Ok(case_matrix)
1112        }
1113        SectionValue::Match { expr, arms } => {
1114            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1115
1116            if let Some(key) = expr_result_flag_key(&expr) {
1117                let target = active_key == Some(key.as_str());
1118                for arm in arms {
1119                    if arm.guard.is_none() {
1120                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1121                            continue;
1122                        }
1123                    }
1124                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1125                    wrap_section_case_matrix_for_match_arm(
1126                        &mut arm_cases,
1127                        &expr,
1128                        &arm.pat,
1129                        arm.guard.as_ref(),
1130                    );
1131                    case_matrix.extend(arm_cases);
1132                }
1133            } else {
1134                for arm in arms {
1135                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1136                    wrap_section_case_matrix_for_match_arm(
1137                        &mut arm_cases,
1138                        &expr,
1139                        &arm.pat,
1140                        arm.guard.as_ref(),
1141                    );
1142                    case_matrix.extend(arm_cases);
1143                }
1144            }
1145
1146            if case_matrix.is_empty() {
1147                return Err("sql_forge!: section match must have at least one arm".to_string());
1148            }
1149
1150            Ok(case_matrix)
1151        }
1152    }
1153}
1154
1155/// This function rewrites a section-local validator expression so it stays inside the same
1156/// match arm scope that originally introduced it
1157fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
1158    let match_expr = match_expr.clone();
1159    let pat = pat.clone();
1160    let pattern_binds_values = match &pat {
1161        Pat::Ident(_) => true,
1162        Pat::Or(pat_or) => pat_or
1163            .cases
1164            .iter()
1165            .any(|case| matches!(case, Pat::Ident(_))),
1166        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
1167        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
1168        Pat::Slice(pat_slice) => pat_slice
1169            .elems
1170            .iter()
1171            .any(|elem| matches!(elem, Pat::Ident(_))),
1172        Pat::Struct(pat_struct) => pat_struct
1173            .fields
1174            .iter()
1175            .any(|field| matches!(*field.pat, Pat::Ident(_))),
1176        Pat::Tuple(pat_tuple) => pat_tuple
1177            .elems
1178            .iter()
1179            .any(|elem| matches!(elem, Pat::Ident(_))),
1180        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
1181            .elems
1182            .iter()
1183            .any(|elem| matches!(elem, Pat::Ident(_))),
1184        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
1185        _ => false,
1186    };
1187
1188    if pattern_binds_values {
1189        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
1190            .into_iter()
1191            .map(|ident| quote! { let _ = &#ident; })
1192            .collect();
1193        if let Some(guard) = guard.cloned() {
1194            parse_quote! {
1195                match &(#match_expr) {
1196                    #pat if #guard => { #( #pat_refs )* #expr },
1197                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1198                }
1199            }
1200        } else {
1201            parse_quote! {
1202                match &(#match_expr) {
1203                    #pat => { #( #pat_refs )* #expr },
1204                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1205                }
1206            }
1207        }
1208    } else if let Some(guard) = guard.cloned() {
1209        parse_quote! {
1210            match &(#match_expr) {
1211                #pat if #guard => { &(#expr) },
1212                _ => unreachable!("sql_forge!: validator arm mismatch"),
1213            }
1214        }
1215    } else {
1216        parse_quote! {
1217            match &(#match_expr) {
1218                #pat => { &(#expr) },
1219                _ => unreachable!("sql_forge!: validator arm mismatch"),
1220            }
1221        }
1222    }
1223}
1224
1225fn wrap_params_source_for_match_arm(
1226    params: &mut ParamsSource,
1227    match_expr: &Expr,
1228    pat: &Pat,
1229    guard: Option<&Expr>,
1230) {
1231    match params {
1232        ParamsSource::None => {}
1233        ParamsSource::Map(entries) => {
1234            for entry in entries {
1235                entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1236            }
1237        }
1238        ParamsSource::Struct(expr) => {
1239            **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1240        }
1241    }
1242}
1243
1244fn wrap_section_case_matrix_for_match_arm(
1245    case_matrix: &mut [Vec<SectionFragment>],
1246    match_expr: &Expr,
1247    pat: &Pat,
1248    guard: Option<&Expr>,
1249) {
1250    for row in case_matrix {
1251        for fragment in row {
1252            wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1253        }
1254    }
1255}
1256
// The section-variant collectors below return Vec<Vec<SectionFragment>> indexed [section_idx][case_idx].
1258// =============================================================================
1259// Section variant collection
1260// =============================================================================
1261
1262/// Collects all possible `SectionFragment` values per section index. Each
1263/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1264/// listing every variant that section can take across all match arms.
1265/// Used for full validation (cycling strategy over all variants).
1266fn collect_section_variants(
1267    value: SectionValue,
1268    width: usize,
1269) -> Result<Vec<Vec<SectionFragment>>, String> {
1270    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1271}
1272
1273fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1274    match strip_expr(expr) {
1275        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1276            let name = path.path.segments[0].ident.to_string();
1277            name.strip_prefix("__enhanced_result_flag_")
1278                .map(|v| v.to_string())
1279        }
1280        _ => None,
1281    }
1282}
1283
1284fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1285    match pat {
1286        Pat::Lit(expr_lit) => match &expr_lit.lit {
1287            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1288            _ => None,
1289        },
1290        Pat::Wild(_) => Some(true),
1291        _ => None,
1292    }
1293}
1294
1295/// Like `collect_section_variants`, but filters `match` arms by the active
1296/// result key when the match expression is a `{>key}` flag. When building the
1297/// query for a specific key, only the matching arm (true/false) is included;
1298/// arms with guards or non-flag expressions include all variants as usual.
1299fn collect_section_variants_for_result(
1300    value: SectionValue,
1301    width: usize,
1302    active_key: Option<&str>,
1303) -> Result<Vec<Vec<SectionFragment>>, String> {
1304    transpose_section_case_matrix(
1305        collect_section_case_matrix(value, width, active_key)?,
1306        width,
1307    )
1308}
1309
1310// =============================================================================
1311// Parameter binding generation
1312// =============================================================================
1313
1314/// Generates `let` bindings for all parameters used in the SQL or section
1315/// fragments. Returns a map of param-name → local ident for later reference,
1316/// and a list of `let` statements.
1317fn build_param_bindings(
1318    params: &ParamsSource,
1319    used_param_names: &[String],
1320    prefix: &str,
1321    for_validator: bool,
1322    enforce_usage_check: bool,
1323) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1324    let mut declared_params = HashMap::<String, syn::Ident>::new();
1325    let mut bindings = Vec::<TokenStream2>::new();
1326
1327    match params {
1328        ParamsSource::None => {}
1329        ParamsSource::Map(entries) => {
1330            for entry in entries {
1331                let key = entry.name.to_string();
1332                if declared_params.contains_key(&key) {
1333                    return Err(syn::Error::new(
1334                        entry.name.span(),
1335                        "sql_forge!: duplicated parameter mapping",
1336                    )
1337                    .to_compile_error()
1338                    .into());
1339                }
1340                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1341                    return Err(syn::Error::new(
1342                        entry.name.span(),
1343                        format!(
1344                            "sql_forge!: parameter :{} is unused in the SQL template",
1345                            key,
1346                        ),
1347                    )
1348                    .to_compile_error()
1349                    .into());
1350                }
1351                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1352                let expr = &entry.expr;
1353                if for_validator {
1354                    bindings.push(quote! {
1355                        let #local_ident = &(#expr);
1356                    });
1357                } else {
1358                    bindings.push(quote! {
1359                        let #local_ident = #expr;
1360                    });
1361                }
1362                declared_params.insert(key, local_ident);
1363            }
1364        }
1365        ParamsSource::Struct(expr) => {
1366            let source_ident = format_ident!("__enhanced_source_{}", prefix);
1367            bindings.push(quote! {
1368                let #source_ident = &(#expr);
1369            });
1370            for name in used_param_names {
1371                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1372                let field_ident = format_ident!("{}", name);
1373                if for_validator {
1374                    bindings.push(quote! {
1375                        let #local_ident = &#source_ident.#field_ident;
1376                    });
1377                } else {
1378                    bindings.push(quote! {
1379                        let #local_ident = #source_ident.#field_ident;
1380                    });
1381                }
1382                declared_params.insert(name.to_string(), local_ident);
1383            }
1384        }
1385    }
1386
1387    Ok((declared_params, bindings))
1388}
1389
1390struct ValidatorRenderContext<'a> {
1391    local_params: &'a HashMap<String, syn::Ident>,
1392    top_level_params: &'a HashMap<String, syn::Ident>,
1393    allow_top_level_fallback: bool,
1394    use_dollar_params: bool,
1395    sql_span: Span,
1396    list_count: usize,
1397}
1398
1399/// Builds the placeholders SQL string and argument list for the compile-time
1400/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1401/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1402/// corresponding value expression is collected into the args list.
1403fn render_validator_args(
1404    sql: &str,
1405    param_offset: &mut usize,
1406    arg_index: &mut usize,
1407    context: &ValidatorRenderContext<'_>,
1408) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1409    let (rendered_sql, occurrences) = render_validator_text(
1410        sql,
1411        context.use_dollar_params,
1412        param_offset,
1413        context.list_count,
1414    );
1415
1416    let mut setup = Vec::<TokenStream2>::new();
1417    let mut args = Vec::<TokenStream2>::new();
1418
1419    for (name, is_list) in occurrences {
1420        let local_ident = if context.allow_top_level_fallback {
1421            context
1422                .local_params
1423                .get(&name)
1424                .or_else(|| context.top_level_params.get(&name))
1425        } else {
1426            context.local_params.get(&name)
1427        };
1428
1429        let Some(local_ident) = local_ident else {
1430            return Err(syn::Error::new(
1431                context.sql_span,
1432                format!("sql_forge!: parameter :{} has no mapping", name),
1433            )
1434            .to_compile_error()
1435            .into());
1436        };
1437
1438        if is_list {
1439            for _ in 0..context.list_count {
1440                let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1441                *arg_index += 1;
1442                if context.use_dollar_params {
1443                    setup.push(quote! {
1444                        let #value_ident = sql_forge::sql_forge_validator_value(
1445                            (#local_ident)
1446                                .as_slice()
1447                                .first()
1448                                .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1449                        );
1450                    });
1451                } else {
1452                    setup.push(quote! {
1453                        let #value_ident = (#local_ident)
1454                            .as_slice()
1455                            .first()
1456                            .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1457                    });
1458                }
1459                args.push(quote! { #value_ident });
1460            }
1461        } else {
1462            let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1463            *arg_index += 1;
1464            if context.use_dollar_params {
1465                setup.push(quote! {
1466                    let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1467                });
1468            } else {
1469                setup.push(quote! {
1470                    let #value_ident = #local_ident;
1471                });
1472            }
1473            args.push(quote! { #value_ident });
1474        }
1475    }
1476
1477    Ok((rendered_sql, setup, args))
1478}
1479
1480// =============================================================================
1481// Runtime code generation (QueryBuilder-based)
1482// =============================================================================
1483
1484/// Generates the `push()` / `push_bind()` calls for a single section fragment
1485/// at runtime using `sqlx::QueryBuilder`.
1486fn render_runtime_fragment(
1487    fragment: &SectionFragment,
1488    local_params: &HashMap<String, syn::Ident>,
1489) -> Result<TokenStream2, TokenStream> {
1490    let mut steps = Vec::<TokenStream2>::new();
1491
1492    for part in parse_text_parts(&fragment.sql) {
1493        match part {
1494            TextPart::Lit(lit) => {
1495                let lit_str = LitStr::new(&lit, fragment.span);
1496                steps.push(quote! { __builder.push(#lit_str); });
1497            }
1498            TextPart::Param { name, is_list } => {
1499                let Some(local_ident) = local_params.get(&name) else {
1500                    return Err(syn::Error::new(
1501                        fragment.span,
1502                        format!("sql_forge!: parameter :{} has no mapping", name),
1503                    )
1504                    .to_compile_error()
1505                    .into());
1506                };
1507
1508                if is_list {
1509                    steps.push(quote! {
1510                        let __enhanced_values = #local_ident;
1511                        let mut __separated = __builder.separated(", ");
1512                        for __value in __enhanced_values {
1513                            __separated.push_bind(__value);
1514                        }
1515                    });
1516                } else {
1517                    steps.push(quote! {
1518                        __builder.push_bind(#local_ident);
1519                    });
1520                }
1521            }
1522        }
1523    }
1524
1525    Ok(quote! { #( #steps )* })
1526}
1527
1528fn is_pat_binding(ident: &Ident) -> bool {
1529    let name = ident.to_string();
1530    !name.is_empty()
1531        && name
1532            .chars()
1533            .next()
1534            .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1535}
1536
1537fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1538    let mut names = Vec::new();
1539    fn walk(p: &Pat, names: &mut Vec<Ident>) {
1540        match p {
1541            Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1542            Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1543            Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1544            Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1545            Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1546            Pat::Paren(pp) => walk(&pp.pat, names),
1547            Pat::Reference(pr) => walk(&pr.pat, names),
1548            Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1549            Pat::Type(pt) => walk(&pt.pat, names),
1550            _ => {}
1551        }
1552    }
1553    walk(pat, &mut names);
1554    names
1555}
1556
1557fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1558    match value {
1559        SectionValue::Single(f) => {
1560            if collect_used_param_names_in_sql(&f.sql)
1561                .iter()
1562                .any(|n| n == name)
1563            {
1564                return true;
1565            }
1566            if let ParamsSource::Map(entries) = &f.params {
1567                for e in entries {
1568                    let expr = &e.expr;
1569                    let expr_str = quote! { #expr }.to_string();
1570                    if expr_str.trim() == name {
1571                        return true;
1572                    }
1573                }
1574            }
1575            false
1576        }
1577        SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1578        SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1579            let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1580                .into_iter()
1581                .map(|i| i.to_string())
1582                .collect();
1583            if pat_vars.contains(name) {
1584                false
1585            } else {
1586                section_value_refers_to(&arm.value, name)
1587            }
1588        }),
1589    }
1590}
1591
1592fn build_section_runtime_action(
1593    value: &SectionValue,
1594    section_idx: usize,
1595    prefix: &str,
1596) -> Result<TokenStream2, TokenStream> {
1597    match value {
1598        SectionValue::Single(fragment) => {
1599            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1600            let (local_params, bindings) =
1601                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1602            let body = render_runtime_fragment(fragment, &local_params)?;
1603            Ok(quote! {{ #( #bindings )* #body }})
1604        }
1605        SectionValue::Grouped(fragments) => build_section_runtime_action(
1606            &fragments[section_idx],
1607            0,
1608            &format!("{}_grouped_{}", prefix, section_idx),
1609        ),
1610        SectionValue::Match { expr, arms } => {
1611            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1612                .iter()
1613                .enumerate()
1614                .map(|(arm_idx, arm)| {
1615                    let pat = &arm.pat;
1616                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1617                    let body = build_section_runtime_action(
1618                        &arm.value,
1619                        section_idx,
1620                        &format!("{}_{}", prefix, arm_idx),
1621                    )?;
1622                    let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1623                        .into_iter()
1624                        .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1625                        .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1626                        .collect();
1627                    Ok::<TokenStream2, TokenStream>(quote! {
1628                        #pat #guard_tokens => {
1629                            #( #noop_refs )*
1630                            #body
1631                        }
1632                    })
1633                })
1634                .collect();
1635            let arm_tokens = arm_tokens?;
1636            Ok(quote! {
1637                match #expr {
1638                    #( #arm_tokens ),*
1639                }
1640            })
1641        }
1642    }
1643}
1644
1645fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1646    let mut names = Vec::new();
1647    let mut seen = HashSet::<String>::new();
1648
1649    for segment in segments {
1650        match segment {
1651            Segment::Text(text) => {
1652                for name in collect_used_param_names_in_sql(text) {
1653                    if seen.insert(name.clone()) {
1654                        names.push(name);
1655                    }
1656                }
1657            }
1658            Segment::Batch { parts } => {
1659                for part in parts {
1660                    if let TextPart::Param { name, .. } = part {
1661                        if seen.insert(name.clone()) {
1662                            names.push(name.clone());
1663                        }
1664                    }
1665                }
1666            }
1667            _ => {}
1668        }
1669    }
1670
1671    names
1672}
1673
1674fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1675    let mut names = Vec::new();
1676    let mut seen = HashSet::<String>::new();
1677    for part in parse_text_parts(sql) {
1678        if let TextPart::Param { name, .. } = part {
1679            if seen.insert(name.to_string()) {
1680                names.push(name);
1681            }
1682        }
1683    }
1684    names
1685}
1686
1687/// Builds a parameterized SQL query with compile-time type-checking and a
1688/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1689///
1690/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1691/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1692///
1693/// # Syntax
1694///
1695/// ```text
1696/// sql_forge!(
1697///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1698///     [Model,]     // optional result spec
1699///     SQL,         // string literal
1700///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1701///     [(sections),]// optional: ( #name = ..., ... )
1702///     [..batch]    // optional: batch source expression used by {( ... )}
1703/// )
1704/// ```
1705///
1706/// `Model` has three forms:
1707/// - omitted: execute-only query; only `.execute(...)` is available
1708/// - `Type` or `scalar Type`: a single result query
1709/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1710///
1711/// The trailing parameter source, section map, and batch source are optional.
1712/// The batch source may appear alongside the others as a single `..expr` argument.
1713///
1714/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1715/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1716/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1717/// The env var takes priority over Cargo.toml metadata.
1718///
1719/// # Parameters
1720///
1721/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1722/// is replaced by a `push_bind` call; at compile time it becomes a
1723/// database-specific placeholder: `?` for MySQL and SQLite, and `$1`, `$2`, ...
1724/// for Postgres.
1725///
1726/// **Inline map** – bind individual expressions:
1727/// ```rust,ignore
1728/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1729/// ```
1730///
1731/// **Struct source** – field names are matched to `:name` placeholders automatically:
1732/// ```rust,ignore
1733/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1734/// ```
1735///
1736/// # Sections (`{#name}`)
1737///
1738/// Sections are runtime SQL slots; each section's variants are validated at
1739/// compile time via `query_as!` / `query_scalar!`, though not every combination
1740/// of variants across sections is checked. The section map is a parenthesised argument placed
1741/// after the SQL (and after the parameter source, if any) whose entries start with `#`:
1742///
1743/// ```rust,ignore
1744/// sql_forge!(
1745///     User,
1746///     "SELECT * FROM users {#join_org}",
1747///     (
1748///         #join_org = match include_org {
1749///             true  => " JOIN organisations o ON o.id = users.org_id ",
1750///             false => "",
1751///         }
1752///     )
1753/// )
1754/// ```
1755///
1756/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1757///
1758/// ```rust,ignore
1759/// sql_forge!(
1760///     User,
1761///     "SELECT * FROM users {#filter}",
1762///     (
1763///         #filter = (
1764///             " WHERE id <= :max_id AND status = :status ",
1765///             ( :max_id = max_id, :status = "active" ),
1766///         )
1767///     )
1768/// )
1769/// ```
1770///
1771/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1772/// returning a tuple of the same width:
1773///
1774/// ```rust,ignore
1775/// sql_forge!(
1776///     User,
1777///     "SELECT * FROM users {#join_org} {#filter_org}",
1778///     (
1779///         #(join_org, filter_org) = match include_org {
1780///             true  => (
1781///                 " JOIN organisations o ON o.id = users.org_id ",
1782///                 (
1783///                     " AND o.active = :active ",
1784///                     ( :active = true ),
1785///                 ),
1786///             ),
1787///             false => ("", ""),
1788///         }
1789///     )
1790/// )
1791/// ```
1792///
1793/// Grouped section items may themselves use nested `match` expressions. Those
1794/// nested matches use smart cycling within the arm rather than a cartesian
1795/// product. For example, if one grouped arm returns a fixed first item plus two
1796/// nested binary matches for the second and third items, that arm contributes
1797/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1798/// `(1, 0)`, `(1, 1)` combinations.
1799///
1800/// # `IN (...)` with list parameters
1801///
1802/// Suffix a placeholder with `[]` (`:name[]`) to expand an iterable such as a `Vec` into a
1803/// comma-separated list of bound values; the surrounding parentheses are written in the SQL:
1804///
1805/// ```rust,ignore
1806/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1807/// ```
1808///
1809/// **Empty lists** are not special-cased: an empty iterable renders as `IN ()`, which most
1810/// databases reject as a syntax error. Guard against empty inputs explicitly, e.g. with a dynamic section:
1811///
1812/// ```rust,ignore
1813/// sql_forge!(
1814///     User,
1815///     "SELECT id, name FROM users WHERE {#filter}",
1816///     (
1817///         #filter = match ids.is_empty() {
1818///             true  => "1 = 0",
1819///             false => ("id IN (:ids[])", ( :ids = ids )),
1820///         }
1821///     )
1822/// )
1823/// ```
1824///
1825/// # Batch inserts (`{( ... )}`)
1826///
1827/// A batch section `{( ... )}` repeats its content for each item in an iterable
1828/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1829/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1830/// sections.
1831///
1832/// ## Struct batch
1833///
1834/// ```rust,ignore
1835/// struct BatchItem { name: String, price: i64 }
1836///
1837/// let items = vec![
1838///     BatchItem { name: "A".into(), price: 100 },
1839///     BatchItem { name: "B".into(), price: 200 },
1840/// ];
1841///
1842/// sql_forge!(
1843///     "INSERT INTO products (name, price, stock, category)
1844///      VALUES {(:name, :price, 10, 'Batch')}",
1845///     ..items
1846/// )
1847/// .execute(&pool)
1848/// .await?;
1849/// ```
1850///
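1851/// For compile-time checking, the validator expands the batch to 3 placeholder copies (1 on SQLite),
1852/// e.g. `(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`, binding the first item's fields (`expr[0]`) in every copy, so the batch source must also support indexing (e.g. a `Vec` or slice).
1853/// At runtime the iterable drives the actual number of rows.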
1854///
1855/// # Scalar output
1856///
1857/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) or is marked `scalar Type`,
1858/// the macro uses `query_scalar!` for validation and `build_query_scalar` for execution.
1859///
1860/// # Multiple results
1861///
1862/// A result map produces a `SqlForgeQueryGroup` with one query per key.
1863/// Each key can be a struct or a primitive (used as a scalar):
1864///
1865/// ```rust,ignore
1866/// sql_forge!(
1867///     (
1868///         >count = i64,
1869///         >items = Item,
1870///     ),
1871///     "SELECT {#fields} FROM items WHERE category_id = :cat",
1872///     ( :cat = category_id ),
1873///     (
1874///         #fields = match {>count} {           // {>key} is true when building
1875///             true  => "COUNT(*) AS total",    // the query for that model/result
1876///             false => "id, name, price",      // key and false otherwise
1877///         }
1878///     )
1879/// )
1880/// ```
1881///
1882/// The generated struct has one field per key (`group.count`, `group.items`),
1883/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
1884/// executor method (`fetch_one`, `fetch_all`, etc.).
1885///
1886/// # Execute-only (no model)
1887///
1888/// When the model type is omitted, the macro produces a value implementing
1889/// `SqlForgeQueryExecute`. Only `.execute(executor)`
1890/// is available and there is no return type to deserialize into. This is useful
1891/// for `INSERT`, `UPDATE`, `DELETE`, and other DML statements.
1892///
1893/// ```rust,ignore
1894/// sql_forge!(
1895///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
1896///     ( :id = 42i64 ),
1897/// )
1898/// .execute(&pool)
1899/// .await?;
1900/// ```
1901///
1902/// Sections and struct parameter sources work the same way as in model-backed queries:
1903///
1904/// ```rust,ignore
1905/// sql_forge!(
1906///     "UPDATE products SET price = :new_price {#filter}", ( :new_price = 1999i64 ),
1907///     ( #filter = ("WHERE category = :cat", ( :cat = "Electronics" )) ),
1908/// )
1909/// .execute(&pool)
1910/// .await?;
1911/// ```
1912#[proc_macro]
1913#[allow(clippy::too_many_lines)]
1914pub fn sql_forge(input: TokenStream) -> TokenStream {
1915    // ---- Phase 1: Parse the macro input into structured data ----
1916    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1917    let SqlForgeInput {
1918        db,
1919        result,
1920        force_scalar,
1921        sql,
1922        params,
1923        sections,
1924        batch,
1925    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1926        Ok(v) => v,
1927        Err(err) => return err.to_compile_error().into(),
1928    };
1929
1930    // ---- Phase 2: Resolve database type (from macro arg or Cargo.toml) ----
1931    let db = match db {
1932        Some(db) => db,
1933        None => match resolve_db_from_env() {
1934            Ok(db) => db,
1935            Err(msg) => {
1936                return syn::Error::new(Span::call_site(), msg)
1937                    .to_compile_error()
1938                    .into();
1939            }
1940        },
1941    };
1942
1943    let use_dollar_params = uses_dollar_params(&db);
1944    let is_sqlite = if let syn::Type::Path(type_path) = &db {
1945        type_path
1946            .path
1947            .segments
1948            .last()
1949            .is_some_and(|s| s.ident == "Sqlite")
1950    } else {
1951        false
1952    };
1953    let list_count: usize = if is_sqlite { 1 } else { 3 };
1954
1955    // ---- Phase 3: Build result case definitions ----
1956    // Each result case is (optional_key, model_type, optional_scalar_type).
1957    // Scalar type is set for primitives and `scalar`-marked types.
1958    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1959        ResultSpec::None => {
1960            vec![(None, None, None)]
1961        }
1962        ResultSpec::Single(ref model) => {
1963            let model_ty = (**model).clone();
1964            let scalar = if force_scalar {
1965                Some(model_ty.clone())
1966            } else {
1967                scalar_output_type(model.as_ref()).cloned()
1968            };
1969            vec![(None, Some(model_ty), scalar)]
1970        }
1971        ResultSpec::Group(ref cases) => {
1972            if force_scalar {
1973                return syn::Error::new(
1974                    Span::call_site(),
1975                    "sql_forge!: scalar mode is not supported for grouped result maps",
1976                )
1977                .to_compile_error()
1978                .into();
1979            }
1980
1981            let mut out = Vec::new();
1982            let mut seen = HashSet::new();
1983            for case in cases {
1984                let key = case.name.to_string();
1985                if !seen.insert(key.clone()) {
1986                    return syn::Error::new(
1987                        case.name.span(),
1988                        "sql_forge!: duplicated key in result map",
1989                    )
1990                    .to_compile_error()
1991                    .into();
1992                }
1993
1994                let model = case.model.clone();
1995                let scalar = if case.force_scalar {
1996                    Some(model.clone())
1997                } else {
1998                    scalar_output_type(&case.model).cloned()
1999                };
2000                out.push((Some(key), Some(model), scalar));
2001            }
2002            out
2003        }
2004    };
2005    let group_result_keys: Vec<String> = result_cases
2006        .iter()
2007        .filter_map(|(key, _, _)| key.as_ref().cloned())
2008        .collect();
2009    let is_grouped_result = !group_result_keys.is_empty();
2010    let sql_span = sql.span();
2011
2012    // ---- Phase 4: Parse SQL into segments (text + {#section} slots) ----
2013    let segments = match sql.into_segments() {
2014        Ok(segments) => segments,
2015        Err(msg) => {
2016            return syn::Error::new(sql_span, msg).to_compile_error().into();
2017        }
2018    };
2019
2020    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2021    match (&batch, has_batch_segment) {
2022        (None, true) => {
2023            return syn::Error::new(
2024                sql_span,
2025                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2026                 was provided"
2027            )
2028            .to_compile_error()
2029            .into();
2030        }
2031        (Some(_), false) => {
2032            return syn::Error::new(
2033                sql_span,
2034                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2035                 batch section",
2036            )
2037            .to_compile_error()
2038            .into();
2039        }
2040        _ => {}
2041    }
2042
2043    let used_param_names = collect_used_param_names(&segments);
2044
2045    // Batch-only params come from batch items, not the top-level params map.
2046    // They must be excluded from the usage check so that a param like :category
2047    // that appears only inside {( ... )} is flagged as unused when given in the
2048    // params map, as it would never be read from there at runtime.
2049    let batch_param_names: std::collections::HashSet<String> = segments
2050        .iter()
2051        .filter_map(|s| {
2052            if let Segment::Batch { parts } = s {
2053                Some(parts.iter().filter_map(|p| {
2054                    if let TextPart::Param { name, .. } = p {
2055                        Some(name.clone())
2056                    } else {
2057                        None
2058                    }
2059                }))
2060            } else {
2061                None
2062            }
2063        })
2064        .flatten()
2065        .collect();
2066    let top_level_used_names: Vec<String> = used_param_names
2067        .iter()
2068        .filter(|n| !batch_param_names.contains(*n))
2069        .cloned()
2070        .collect();
2071
2072    // ---- Phase 5: Build parameter bindings for the top-level params ----
2073    let (declared_params, validator_param_bindings) =
2074        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
2075            Ok(v) => v,
2076            Err(err) => return err,
2077        };
2078
2079    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2080
2081    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
2082    for assign in &sections {
2083        let SectionAssign { names, value } = assign;
2084
2085        // Build runtime actions first, while `value` is still available by reference.
2086        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2087        for (section_idx, name_ident) in names.iter().enumerate() {
2088            let name = name_ident.to_string();
2089            if runtime_section_actions.contains_key(&name) {
2090                return syn::Error::new(
2091                    name_ident.span(),
2092                    "sql_forge!: duplicated section mapping",
2093                )
2094                .to_compile_error()
2095                .into();
2096            }
2097            let action = match build_section_runtime_action(
2098                value,
2099                section_idx,
2100                &format!("section_{}", name),
2101            ) {
2102                Ok(action) => action,
2103                Err(err) => return err,
2104            };
2105            named_actions.push((name, action));
2106        }
2107
2108        // Consume `value` here so invalid grouped/nested section structures fail early.
2109        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2110            return syn::Error::new(names[0].span(), msg)
2111                .to_compile_error()
2112                .into();
2113        }
2114
2115        for (name, action) in named_actions {
2116            runtime_section_actions.insert(name, action);
2117        }
2118    }
2119
2120    let sql_section_names: std::collections::HashSet<&str> = segments
2121        .iter()
2122        .filter_map(|seg| {
2123            if let Segment::Section { name } = seg {
2124                Some(name.as_str())
2125            } else {
2126                None
2127            }
2128        })
2129        .collect();
2130    for name in runtime_section_actions.keys() {
2131        if !sql_section_names.contains(name.as_str()) {
2132            return syn::Error::new(
2133                sql_span,
2134                format!(
2135                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2136                    name, name,
2137                ),
2138            )
2139            .to_compile_error()
2140            .into();
2141        }
2142    }
2143
2144    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
2145    let mut generated_query_defs = Vec::<TokenStream2>::new();
2146    let mut generated_query_values = Vec::<TokenStream2>::new();
2147    let mut group_field_defs = Vec::<TokenStream2>::new();
2148    let mut group_method_defs = Vec::<TokenStream2>::new();
2149    let mut group_field_idents = Vec::<syn::Ident>::new();
2150    let mut group_field_tys = Vec::<TokenStream2>::new();
2151    let mut group_trait_impls = Vec::<TokenStream2>::new();
2152
2153    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2154
2155    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2156        let suffix = result_key.as_deref().unwrap_or("single");
2157        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2158        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2159
2160        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2161
2162        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2163        for assign in &sections {
2164            let SectionAssign { names, value } = assign;
2165            let variants_by_section = match collect_section_variants_for_result(
2166                value.clone(),
2167                names.len(),
2168                result_key.as_deref(),
2169            ) {
2170                Ok(v) => v,
2171                Err(msg) => {
2172                    return syn::Error::new(names[0].span(), msg)
2173                        .to_compile_error()
2174                        .into();
2175                }
2176            };
2177
2178            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2179                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2180            }
2181        }
2182
2183        let mut nmax = 1usize;
2184        for segment in &segments {
2185            if let Segment::Section { name } = segment {
2186                if let Some(variants) = section_variants_for_validation.get(name) {
2187                    if variants.is_empty() {
2188                        return syn::Error::new(
2189                            sql_span,
2190                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
2191                        )
2192                        .to_compile_error()
2193                        .into();
2194                    }
2195                    nmax = nmax.max(variants.len());
2196                } else {
2197                    return syn::Error::new(
2198                        sql_span,
2199                        format!("sql_forge!: section {{#{}}} has no mapping", name),
2200                    )
2201                    .to_compile_error()
2202                    .into();
2203                }
2204            }
2205        }
2206
2207        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2208        for case_idx in 0..nmax {
2209            let mut sql_case = String::new();
2210            let mut case_setup = Vec::<TokenStream2>::new();
2211            let mut case_args = Vec::<TokenStream2>::new();
2212            let mut param_offset = 0usize;
2213            let mut arg_index = 0usize;
2214            let empty_params = HashMap::<String, syn::Ident>::new();
2215            let root_validator_context = ValidatorRenderContext {
2216                local_params: &empty_params,
2217                top_level_params: &declared_params,
2218                allow_top_level_fallback: true,
2219                use_dollar_params,
2220                sql_span,
2221                list_count,
2222            };
2223
2224            for segment in &segments {
2225                match segment {
2226                    Segment::Text(text) => {
2227                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2228                            text,
2229                            &mut param_offset,
2230                            &mut arg_index,
2231                            &root_validator_context,
2232                        ) {
2233                            Ok(value) => value,
2234                            Err(err) => return err,
2235                        };
2236                        sql_case.push_str(&chunk_sql);
2237                        case_setup.extend(chunk_setup);
2238                        case_args.extend(chunk_args);
2239                    }
2240                    Segment::Section { name } => {
2241                        let Some(variants) = section_variants_for_validation.get(name) else {
2242                            return syn::Error::new(
2243                                sql_span,
2244                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2245                            )
2246                            .to_compile_error()
2247                            .into();
2248                        };
2249
2250                        let fragment = &variants[case_idx % variants.len()];
2251                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2252                        let (local_params, bindings) = match build_param_bindings(
2253                            &fragment.params,
2254                            &used_param_names,
2255                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2256                            true,
2257                            true,
2258                        ) {
2259                            Ok(value) => value,
2260                            Err(err) => return err,
2261                        };
2262                        let section_validator_context = ValidatorRenderContext {
2263                            local_params: &local_params,
2264                            top_level_params: &declared_params,
2265                            allow_top_level_fallback: false,
2266                            use_dollar_params,
2267                            sql_span: fragment.span,
2268                            list_count,
2269                        };
2270                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2271                            &fragment.sql,
2272                            &mut param_offset,
2273                            &mut arg_index,
2274                            &section_validator_context,
2275                        ) {
2276                            Ok(value) => value,
2277                            Err(err) => return err,
2278                        };
2279                        sql_case.push_str(&chunk_sql);
2280                        case_setup.extend(bindings);
2281                        case_setup.extend(chunk_setup);
2282                        case_args.extend(chunk_args);
2283                    }
2284                    Segment::Batch { parts } => {
2285                        let mut first = true;
2286                        for _ in 0..list_count {
2287                            let sep = if first { "" } else { ", " };
2288                            first = false;
2289                            sql_case.push_str(sep);
2290                            for tp in parts {
2291                                match tp {
2292                                    TextPart::Lit(lit) => sql_case.push_str(lit),
2293                                    TextPart::Param { name, .. } => {
2294                                        if let Some(batch_expr) = &batch {
2295                                            let field_ident = format_ident!("{}", name);
2296                                            if use_dollar_params {
2297                                                param_offset += 1;
2298                                                write!(sql_case, "${}", param_offset).unwrap();
2299                                            } else {
2300                                                sql_case.push('?');
2301                                            }
2302                                            case_args.push(quote! { #batch_expr[0].#field_ident });
2303                                        } else if use_dollar_params {
2304                                            param_offset += 1;
2305                                            write!(sql_case, "${}", param_offset).unwrap();
2306                                        } else {
2307                                            sql_case.push('?');
2308                                        }
2309                                    }
2310                                }
2311                            }
2312                        }
2313                    }
2314                }
2315            }
2316
2317            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2318        }
2319
2320        let mut validator_invocations = Vec::<TokenStream2>::new();
2321        for (sql_lit, case_setup, args) in &validator_cases {
2322            if model_opt.is_none() {
2323                if args.is_empty() {
2324                    validator_invocations.push(quote! {
2325                        {
2326                            #( #case_setup )*
2327                            let _ = sqlx::query_scalar!(
2328                                #sql_lit,
2329                            );
2330                        }
2331                    });
2332                } else {
2333                    validator_invocations.push(quote! {
2334                        {
2335                            #( #case_setup )*
2336                            let _ = sqlx::query_scalar!(
2337                                #sql_lit,
2338                                #( #args ),*
2339                            );
2340                        }
2341                    });
2342                }
2343            } else if let Some(scalar_ty) = scalar_model_ty {
2344                if args.is_empty() {
2345                    validator_invocations.push(quote! {
2346                        {
2347                            #( #case_setup )*
2348                            let _ = sqlx::query_scalar!(
2349                                #sql_lit,
2350                            );
2351                        }
2352                    });
2353                } else {
2354                    validator_invocations.push(quote! {
2355                        {
2356                            #( #case_setup )*
2357                            let _ = sqlx::query_scalar!(
2358                                #sql_lit,
2359                                #( #args ),*
2360                            );
2361                        }
2362                    });
2363                }
2364                let _ = scalar_ty;
2365            } else if args.is_empty() {
2366                validator_invocations.push(quote! {
2367                    {
2368                        #( #case_setup )*
2369                        let _ = sqlx::query_as!(
2370                            __EnhancedModel,
2371                            #sql_lit,
2372                        );
2373                    }
2374                });
2375            } else {
2376                validator_invocations.push(quote! {
2377                    {
2378                        #( #case_setup )*
2379                        let _ = sqlx::query_as!(
2380                            __EnhancedModel,
2381                            #sql_lit,
2382                            #( #args ),*
2383                        );
2384                    }
2385                });
2386            }
2387        }
2388
2389        let model_alias = if let Some(model) = model_opt {
2390            if scalar_model_ty.is_none() {
2391                quote! { type __EnhancedModel = #model; }
2392            } else {
2393                quote! {}
2394            }
2395        } else {
2396            quote! {}
2397        };
2398        grouped_validator_invocations.push(quote! {
2399            {
2400                #( #flag_bindings )*
2401                #model_alias
2402                #( #validator_invocations )*
2403            }
2404        });
2405
2406        let (runtime_declared_params, runtime_param_bindings) =
2407            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2408                Ok(v) => v,
2409                Err(err) => return err,
2410            };
2411
2412        let mut runtime_steps = Vec::<TokenStream2>::new();
2413        for (seg_idx, segment) in segments.iter().enumerate() {
2414            match segment {
2415                Segment::Text(text) => {
2416                    for part in parse_text_parts(text) {
2417                        match part {
2418                            TextPart::Lit(lit) => {
2419                                let lit = sanitize_runtime_sql_text(&lit);
2420                                let lit_str = LitStr::new(&lit, sql_span);
2421                                runtime_steps.push(quote! {
2422                                    __builder.push(#lit_str);
2423                                });
2424                            }
2425                            TextPart::Param { name, is_list } => {
2426                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2427                                    return syn::Error::new(
2428                                        sql_span,
2429                                        format!("sql_forge!: parameter :{} has no mapping", name),
2430                                    )
2431                                    .to_compile_error()
2432                                    .into();
2433                                };
2434
2435                                if is_list {
2436                                    runtime_steps.push(quote! {
2437                                        let __enhanced_values = #local_ident;
2438                                        let mut __separated = __builder.separated(", ");
2439                                        for __value in __enhanced_values {
2440                                            __separated.push_bind(__value);
2441                                        }
2442                                    });
2443                                } else {
2444                                    runtime_steps.push(quote! {
2445                                        __builder.push_bind(#local_ident);
2446                                    });
2447                                }
2448                            }
2449                        }
2450                    }
2451                }
2452                Segment::Section { name } => {
2453                    let Some(section_action) = runtime_section_actions.get(name) else {
2454                        let _ = seg_idx;
2455                        return syn::Error::new(
2456                            sql_span,
2457                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2458                        )
2459                        .to_compile_error()
2460                        .into();
2461                    };
2462                    runtime_steps.push(quote! {
2463                        #section_action
2464                    });
2465                }
2466                Segment::Batch { parts } => {
2467                    if let Some(batch_expr) = &batch {
2468                        let mut body = Vec::<TokenStream2>::new();
2469                        for part in parts {
2470                            match part {
2471                                TextPart::Lit(lit) => {
2472                                    let lit_str = LitStr::new(lit, sql_span);
2473                                    body.push(quote! {
2474                                        __builder.push(#lit_str);
2475                                    });
2476                                }
2477                                TextPart::Param { name, .. } => {
2478                                    let field_ident = format_ident!("{}", name);
2479                                    body.push(quote! {
2480                                        __builder.push_bind(__item.#field_ident);
2481                                    });
2482                                }
2483                            }
2484                        }
2485                        runtime_steps.push(quote! {
2486                            {
2487                                let mut __first = true;
2488                                for __item in #batch_expr {
2489                                    if !__first {
2490                                        __builder.push(", ");
2491                                    }
2492                                    __first = false;
2493                                    #( #body )*
2494                                }
2495                            }
2496                        });
2497                    }
2498                }
2499            }
2500        }
2501
2502        let exec_methods = if model_opt.is_none() {
2503            quote! {
2504                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2505                where
2506                    E: sqlx::Executor<'e, Database = #db>,
2507                {
2508                    self.inner.build().execute(executor).await
2509                }
2510            }
2511        } else if let Some(scalar_ty) = scalar_model_ty {
2512            quote! {
2513                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2514                where
2515                    E: sqlx::Executor<'e, Database = #db>,
2516                {
2517                    self.inner
2518                        .build_query_scalar::<#scalar_ty>()
2519                        .fetch_all(executor)
2520                        .await
2521                }
2522
2523                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2524                where
2525                    E: sqlx::Executor<'e, Database = #db>,
2526                {
2527                    self.inner
2528                        .build_query_scalar::<#scalar_ty>()
2529                        .fetch_one(executor)
2530                        .await
2531                }
2532
2533                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2534                where
2535                    E: sqlx::Executor<'e, Database = #db>,
2536                {
2537                    self.inner
2538                        .build_query_scalar::<#scalar_ty>()
2539                        .fetch_optional(executor)
2540                        .await
2541                }
2542
2543                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2544                where
2545                    E: sqlx::Executor<'e, Database = #db>,
2546                {
2547                    self.inner.build().execute(executor).await
2548                }
2549            }
2550        } else {
2551            let model = model_opt.as_ref().unwrap();
2552            quote! {
2553                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2554                where
2555                    E: sqlx::Executor<'e, Database = #db>,
2556                {
2557                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2558                }
2559
2560                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2561                where
2562                    E: sqlx::Executor<'e, Database = #db>,
2563                {
2564                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2565                }
2566
2567                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2568                where
2569                    E: sqlx::Executor<'e, Database = #db>,
2570                {
2571                    self.inner
2572                        .build_query_as::<#model>()
2573                        .fetch_optional(executor)
2574                        .await
2575                }
2576
2577                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2578                where
2579                    E: sqlx::Executor<'e, Database = #db>,
2580                {
2581                    self.inner.build().execute(executor).await
2582                }
2583            }
2584        };
2585
2586        let final_type: TokenStream2 = if let Some(model) = model_opt {
2587            if let Some(scalar_ty) = scalar_model_ty {
2588                quote! { #scalar_ty }
2589            } else {
2590                quote! { #model }
2591            }
2592        } else {
2593            quote! {}
2594        };
2595        let trait_impl = if model_opt.is_none() {
2596            quote! {
2597                impl<'args> sql_forge::SqlForgeQueryExecute
2598                    for #query_ident<'args>
2599                {
2600                    type Db = #db;
2601
2602                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2603                    where
2604                        Self: Sized + 'e,
2605                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2606                        #db: 'e,
2607                    {
2608                        #query_ident::execute(self, executor)
2609                    }
2610                }
2611            }
2612        } else {
2613            quote! {
2614                impl<'args> sql_forge::SqlForgeQuery<#final_type>
2615                    for #query_ident<'args>
2616                {
2617                    type Db = #db;
2618
2619                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2620                    where
2621                        Self: Sized + 'e,
2622                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2623                        #db: 'e,
2624                    {
2625                        #query_ident::fetch_all(self, executor)
2626                    }
2627
2628                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2629                    where
2630                        Self: Sized + 'e,
2631                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2632                        #db: 'e,
2633                    {
2634                        #query_ident::fetch_one(self, executor)
2635                    }
2636
2637                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2638                    where
2639                        Self: Sized + 'e,
2640                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2641                        #db: 'e,
2642                    {
2643                        #query_ident::fetch_optional(self, executor)
2644                    }
2645
2646                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2647                    where
2648                        Self: Sized + 'e,
2649                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2650                        #db: 'e,
2651                    {
2652                        #query_ident::execute(self, executor)
2653                    }
2654                }
2655            }
2656        };
2657
2658        generated_query_defs.push(quote! {
2659            struct #query_ident<'args> {
2660                inner: sqlx::QueryBuilder<'args, #db>,
2661            }
2662
2663            impl<'args> #query_ident<'args> {
2664                #exec_methods
2665            }
2666
2667            #trait_impl
2668        });
2669
2670        generated_query_values.push(quote! {
2671            #( #runtime_param_bindings )*
2672            #( #flag_bindings )*
2673            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2674            #( #runtime_steps )*
2675            let #query_value_ident = #query_ident { inner: __builder };
2676        });
2677
2678        if let Some(key) = result_key {
2679            let method_ident = format_ident!("{}", key);
2680            group_field_defs.push(quote! {
2681                #method_ident: #query_ident<'args>
2682            });
2683            group_field_tys.push(quote! { #query_ident<'args> });
2684            group_method_defs.push(quote! {
2685                pub fn #method_ident(self) -> #query_ident<'args> {
2686                    self.#method_ident
2687                }
2688            });
2689
2690            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2691            group_trait_impls.push(quote! {
2692                struct #key_ty_ident;
2693
2694                impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2695                    type Query = #query_ident<'args>;
2696
2697                    fn get(self, _: #key_ty_ident) -> Self::Query {
2698                        self.#method_ident
2699                    }
2700                }
2701            });
2702            group_field_idents.push(method_ident);
2703        }
2704    }
2705
2706    // ---- Phase 8: Emit the final token stream ----
2707    let validator_tokens = quote! {
2708        let _sql_forge_validator = || {
2709            #( #validator_param_bindings )*
2710            #( #grouped_validator_invocations )*
2711        };
2712    };
2713
2714    if !is_grouped_result {
2715        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2716        return quote! {
2717            {
2718                #validator_tokens
2719                #( #generated_query_defs )*
2720                #( #generated_query_values )*
2721                #single_query_value_ident
2722            }
2723        }
2724        .into();
2725    }
2726
2727    let group_field_inits: Vec<TokenStream2> = result_cases
2728        .iter()
2729        .filter_map(|(key, _, _)| key.as_ref())
2730        .map(|key| {
2731            let method_ident = format_ident!("{}", key);
2732            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2733            quote! { #method_ident: #query_value_ident }
2734        })
2735        .collect();
2736
2737    quote! {
2738        {
2739            #validator_tokens
2740
2741            #( #generated_query_defs )*
2742            #( #generated_query_values )*
2743
2744            struct __SqlForgeQueryGroup<'args> {
2745                #( #group_field_defs, )*
2746            }
2747
2748            impl<'args> __SqlForgeQueryGroup<'args> {
2749                #( #group_method_defs )*
2750
2751                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2752                    ( #( self.#group_field_idents ),* )
2753                }
2754            }
2755
2756            impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2757                type Db = #db;
2758            }
2759
2760            #( #group_trait_impls )*
2761
2762            __SqlForgeQueryGroup {
2763                #( #group_field_inits, )*
2764            }
2765        }
2766    }
2767    .into()
2768}
2769
2770/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2771/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2772///
2773/// ```rust,ignore
2774/// use sql_forge::db_type;
2775///
2776/// pub type AppDb = db_type!();
2777/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2778/// ```
2779///
2780/// Priority:
2781/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2782/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2783#[proc_macro]
2784pub fn db_type(input: TokenStream) -> TokenStream {
2785    if !input.is_empty() {
2786        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2787            .to_compile_error()
2788            .into();
2789    }
2790
2791    match resolve_db_from_env() {
2792        Ok(db) => quote! { #db }.into(),
2793        Err(msg) => syn::Error::new(Span::call_site(), msg)
2794            .to_compile_error()
2795            .into(),
2796    }
2797}
2798
2799/// Marks a single-value tuple struct as a transparent wrapper for use with
2800/// `sql_forge!` parameters.
2801///
2802/// Expands to `#[derive(sqlx::Type)]` + `#[sqlx(transparent)]` (needed for
2803/// all database backends so the type implements `sqlx::Encode` + `sqlx::Type`)
2804/// and additionally implements `SqlForgeValidatorValue<InnerType>`, which is
2805/// **required for PostgreSQL** to pass compile-time parameter validation in
2806/// `query_as!`. MySQL and SQLite do not use the trait.
2807///
2808/// ```rust,ignore
2809/// #[derive(Debug, PartialEq, Eq)]
2810/// #[sql_forge_transparent]
2811/// struct UserId(pub i64);
2812/// ```
2813#[proc_macro_attribute]
2814pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2815    let input: ItemStruct = match syn::parse(item) {
2816        Ok(v) => v,
2817        Err(err) => return err.to_compile_error().into(),
2818    };
2819
2820    let struct_name = &input.ident;
2821    let inner_type = match &input.fields {
2822        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2823        _ => {
2824            return syn::Error::new(
2825                input.span(),
2826                "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2827            )
2828            .to_compile_error()
2829            .into();
2830        }
2831    };
2832
2833    let attrs = input.attrs;
2834    let generics = &input.generics;
2835    let vis = &input.vis;
2836    let struct_token = input.struct_token;
2837    let semi_token = input.semi_token;
2838    let fields = &input.fields;
2839
2840    let expanded = quote! {
2841        #( #attrs )*
2842        #[derive(sqlx::Type)]
2843        #[sqlx(transparent)]
2844        #vis #struct_token #struct_name #generics #fields #semi_token
2845
2846        impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2847            fn sql_forge_validator_value(&self) -> #inner_type {
2848                self.0.clone()
2849            }
2850        }
2851    };
2852
2853    expanded.into()
2854}