Skip to main content

sql_forge_macro/
lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17    Stmt, Token, Type,
18};
19
20// =============================================================================
21// Input types
22// =============================================================================
23
/// Custom keywords recognised by the macro's input grammar.
mod kw {
    // `scalar` may prefix a model type (`scalar i64`, or `>key = scalar i64` inside a
    // result map); when present the corresponding `force_scalar` flag is set to `true`.
    syn::custom_keyword!(scalar);
}
27
28/// A `:name = expr` parameter binding.
29#[derive(Clone)]
30struct ParamAssign {
31    name: Ident,
32    expr: Expr,
33}
34
35impl Parse for ParamAssign {
36    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37        input.parse::<Token![:]>()?;
38        let name: Ident = input.parse()?;
39        input.parse::<Token![=]>()?;
40        let expr: Expr = input.parse()?;
41        Ok(Self { name, expr })
42    }
43}
44
/// A single section value: SQL string + optional local parameter bindings.
///
/// Built by `parse_section_fragment` from either a bare string literal (then `params` is
/// `ParamsSource::None`) or a `("sql", params)` tuple.
#[derive(Clone)]
struct SectionFragment {
    /// The literal's value (as returned by `LitStr::value`, i.e. with escapes resolved).
    sql: String,
    /// Span of the string-literal expression (presumably used to anchor diagnostics).
    span: Span,
    /// Parameter source local to this fragment.
    params: ParamsSource,
}
52
/// One arm in a `match { ... }` inside a section assignment.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern; parsed with `Pat::parse_multi_with_leading_vert`, so `|`-alternatives
    /// and a leading `|` are accepted.
    pat: Pat,
    /// Optional `if` guard.
    guard: Option<Expr>,
    /// Value produced by this arm, parsed with the same width as the enclosing assignment.
    value: SectionValue,
}
60
/// The right-hand side of a section `#name = ...` assignment.
#[derive(Clone)]
enum SectionValue {
    /// A plain string (or `(string, params)` tuple).
    Single(SectionFragment),
    /// A tuple of values, one per section when using `#(a, b) = ...`.
    ///
    /// The parser rejects tuples whose length differs from the number of grouped names;
    /// each item is parsed as a width-1 value (a fragment or a nested `match`).
    Grouped(Vec<SectionValue>),
    /// A `match expr { arm => ..., arm => ... }` expression.
    ///
    /// Every arm is parsed with the width of the enclosing assignment, so for
    /// `#(a, b) = match ...` each arm must yield a two-item tuple.
    Match {
        /// Scrutinee, parsed with `Expr::parse_without_eager_brace` so the arm block is
        /// not consumed as a struct literal.
        expr: Expr,
        /// Arms in source order (the parser does not require at least one).
        arms: Vec<SectionMatchArm>,
    },
}
74
/// A full `#name = value` or `#(a, b) = value` section assignment.
#[derive(Clone)]
struct SectionAssign {
    /// One name for `#name`, several for `#(a, b)`; never empty (the parser rejects `#()`).
    names: Vec<Ident>,
    /// Assigned value; its width equals `names.len()`.
    value: SectionValue,
}
81
82impl Parse for SectionAssign {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        input.parse::<Token![#]>()?;
85
86        // Parse `#(a, b)` for grouped sections, or `#name` for single.
87        let names = if input.peek(syn::token::Paren) {
88            let content;
89            syn::parenthesized!(content in input);
90            let mut out = Vec::new();
91            while !content.is_empty() {
92                out.push(content.parse::<Ident>()?);
93                if content.is_empty() {
94                    break;
95                }
96                content.parse::<Token![,]>()?;
97            }
98            if out.is_empty() {
99                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100            }
101            out
102        } else {
103            vec![input.parse::<Ident>()?]
104        };
105
106        input.parse::<Token![=]>()?;
107        let value = parse_section_value(input, names.len())?;
108        Ok(Self { names, value })
109    }
110}
111
/// The fully-parsed macro invocation.
struct SqlForgeInput {
    /// DB type passed explicitly as the first macro argument, if any.
    /// NOTE(review): when `None`, presumably resolved via `resolve_db_from_env` — confirm
    /// at the use site.
    db: Option<Type>,
    /// The result argument: none, a single model, or a result map.
    result: ResultSpec,
    /// `true` when the single model type was prefixed with `scalar`.
    force_scalar: bool,
    /// The SQL string-literal template.
    sql: SqlTemplate,
    /// Top-level parameter source (`ParamsSource::None` if omitted).
    params: ParamsSource,
    /// Entries of the `(#name = ..., ...)` argument; empty if omitted.
    sections: Vec<SectionAssign>,
    /// The expression following `..` (batch source), if given.
    batch: Option<Expr>,
}
122
/// One entry in a result map `(>key = Model, ...)`.
#[derive(Clone)]
struct ResultAssign {
    /// Result key: the identifier after `>`.
    name: Ident,
    /// Model type for this key.
    model: Type,
    /// `true` when written `>key = scalar Model`.
    force_scalar: bool,
}
130
/// The macro's result argument.
#[derive(Clone)]
enum ResultSpec {
    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
    None,
    /// e.g. `sql_forge!(User, ...)`
    Single(Box<Type>),
    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
    Group(Vec<ResultAssign>),
}
140
/// Where the values for `:name` placeholders come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was supplied.
    None,
    /// `( :name = expr, ... )`
    Map(Vec<ParamAssign>),
    /// A struct expression whose fields are matched to `:name` placeholders.
    ///
    /// The parser stores here *any* expression that is not a recognised `(:..)`, `(>..)`
    /// or `(#..)` map, so this is not guaranteed to be a struct literal.
    Struct(Box<Expr>),
}
149
/// The SQL template. Only string literals are supported.
enum SqlTemplate {
    /// The template exactly as the string literal written in the macro call.
    Literal(LitStr),
}
154
155impl SqlTemplate {
156    fn span(&self) -> Span {
157        match self {
158            Self::Literal(lit) => lit.span(),
159        }
160    }
161
162    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
163    fn into_segments(self) -> Result<Vec<Segment>, String> {
164        match self {
165            Self::Literal(lit) => parse_literal_segments(&lit.value()),
166        }
167    }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171    if input.peek(LitStr) {
172        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173    } else {
174        Err(input.error("sql_forge!: SQL template must be a string literal"))
175    }
176}
177
/// One piece of the parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Plain SQL text (may contain `:param` placeholders).
    ///
    /// `parse_literal_segments` adds text only through `push_text_segment`, which merges
    /// into a preceding `Text`, so two `Text` segments never appear back to back.
    Text(String),
    /// A `{#section_name}` placeholder.
    ///
    /// `name` is the raw text between `{#` and `}`. It is guaranteed non-empty but is not
    /// trimmed or otherwise validated.
    Section { name: String },
    /// A `{( ... )}` batch value template (repeated per batch item).
    ///
    /// `parts` is the split of the text between `{` and `}`, including the outer
    /// parentheses. It never contains list (`:name[]`) parameters.
    Batch { parts: Vec<TextPart> },
}
188
/// A fragment of SQL text after splitting on `:param`.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:param` placeholder.
    ///
    /// `name` excludes the leading `:`. `is_list` is set when the name is immediately
    /// followed by `[]` (written `:name[]`).
    Param { name: String, is_list: bool },
}
197
198// =============================================================================
199// Parsing helpers
200// =============================================================================
201
/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
///
/// Classification looks only at the first token inside the parentheses.
enum MapKind {
    /// `(>name = Model, ...)`: result map.
    Results,
    /// `(:name = expr, ...)`: parameter map.
    Params,
    /// `(#name = value, ...)`: section map.
    Sections,
}
208
209// =============================================================================
210// Section value parsing (string fragments, match, grouped tuples)
211// =============================================================================
212
213fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
214    let fork = input.fork();
215    let content;
216    syn::parenthesized!(content in fork);
217
218    if content.is_empty() {
219        return Err(input.error("sql_forge!: map argument cannot be empty"));
220    }
221
222    if content.peek(Token![>]) {
223        Ok(Some(MapKind::Results))
224    } else if content.peek(Token![:]) {
225        Ok(Some(MapKind::Params))
226    } else if content.peek(Token![#]) {
227        Ok(Some(MapKind::Sections))
228    } else {
229        Ok(None)
230    }
231}
232
233impl Parse for ResultAssign {
234    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
235        input.parse::<Token![>]>()?;
236        let name: Ident = input.parse()?;
237        input.parse::<Token![=]>()?;
238        let (force_scalar, model) = if input.peek(kw::scalar) {
239            input.parse::<kw::scalar>()?;
240            (true, input.parse::<Type>()?)
241        } else {
242            (false, input.parse::<Type>()?)
243        };
244        Ok(Self {
245            name,
246            model,
247            force_scalar,
248        })
249    }
250}
251
252fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
253    let content;
254    syn::parenthesized!(content in input);
255
256    let mut results = Vec::new();
257    while !content.is_empty() {
258        results.push(content.parse::<ResultAssign>()?);
259        if content.is_empty() {
260            break;
261        }
262        content.parse::<Token![,]>()?;
263    }
264
265    if results.is_empty() {
266        return Err(input.error("sql_forge!: result map cannot be empty"));
267    }
268
269    Ok(results)
270}
271
272fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
273    let content;
274    syn::parenthesized!(content in input);
275
276    let mut params = Vec::new();
277    while !content.is_empty() {
278        params.push(content.parse::<ParamAssign>()?);
279        if content.is_empty() {
280            break;
281        }
282        content.parse::<Token![,]>()?;
283    }
284
285    Ok(params)
286}
287
288fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
289    let content;
290    syn::parenthesized!(content in input);
291
292    let mut sections = Vec::new();
293    while !content.is_empty() {
294        sections.push(content.parse::<SectionAssign>()?);
295        if content.is_empty() {
296            break;
297        }
298        content.parse::<Token![,]>()?;
299    }
300
301    Ok(sections)
302}
303
304fn parse_params_source_expr(
305    input: ParseStream<'_>,
306    allow_sections: bool,
307) -> syn::Result<ParamsSource> {
308    if input.peek(syn::token::Paren) {
309        match detect_parenthesized_map_kind(input)? {
310            Some(MapKind::Results) => Err(input
311                .error("sql_forge!: result maps are only allowed as the macro result argument")),
312            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
313            Some(MapKind::Sections) if allow_sections => {
314                Err(input.error("sql_forge!: section maps are not allowed here"))
315            }
316            Some(MapKind::Sections) => Err(input.error(
317                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
318            )),
319            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
320        }
321    } else {
322        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
323    }
324}
325
326fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
327    if input.peek(syn::token::Paren) {
328        let fork = input.fork();
329        let content;
330        syn::parenthesized!(content in fork);
331
332        if let Ok(first_expr) = content.parse::<Expr>() {
333            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
334                let _ = parse_params_source_expr(&content, false)?;
335                if content.peek(Token![,]) {
336                    content.parse::<Token![,]>()?;
337                }
338                if content.is_empty() {
339                    let content;
340                    syn::parenthesized!(content in input);
341                    let first_expr: Expr = content.parse()?;
342                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
343                        input.error("sql_forge!: section tuple must start with a string literal")
344                    })?;
345                    let span = first_expr.span();
346                    content.parse::<Token![,]>()?;
347                    let params = parse_params_source_expr(&content, false)?;
348                    if content.peek(Token![,]) {
349                        content.parse::<Token![,]>()?;
350                    }
351                    if !content.is_empty() {
352                        return Err(content.error(
353                            "sql_forge!: unexpected tokens after section-local parameter source",
354                        ));
355                    }
356                    return Ok(SectionFragment { sql, span, params });
357                }
358            }
359        }
360    }
361
362    let expr: Expr = input.parse()?;
363    let sql = extract_lit_str(&expr).ok_or_else(|| {
364        input
365            .error("sql_forge!: section values must be string literals or (string literal, params)")
366    })?;
367    Ok(SectionFragment {
368        sql,
369        span: expr.span(),
370        params: ParamsSource::None,
371    })
372}
373
374fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
375    if input.peek(Token![match]) {
376        input.parse::<Token![match]>()?;
377        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
378        let content;
379        syn::braced!(content in input);
380        let mut arms = Vec::new();
381        while !content.is_empty() {
382            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
383            let guard = if content.peek(Token![if]) {
384                content.parse::<Token![if]>()?;
385                Some(content.parse::<Expr>()?)
386            } else {
387                None
388            };
389            content.parse::<Token![=>]>()?;
390            let value = parse_section_value(&content, width)?;
391            if content.peek(Token![,]) {
392                content.parse::<Token![,]>()?;
393            }
394            arms.push(SectionMatchArm { pat, guard, value });
395        }
396        return Ok(SectionValue::Match { expr, arms });
397    }
398
399    if width == 1 {
400        return Ok(SectionValue::Single(parse_section_fragment(input)?));
401    }
402
403    let content;
404    syn::parenthesized!(content in input);
405    let mut items = Vec::new();
406    while !content.is_empty() {
407        items.push(parse_section_value(&content, 1)?);
408        if content.is_empty() {
409            break;
410        }
411        content.parse::<Token![,]>()?;
412    }
413
414    if items.len() != width {
415        return Err(input.error(format!(
416            "sql_forge!: grouped section value must provide exactly {} items",
417            width,
418        )));
419    }
420
421    Ok(SectionValue::Grouped(items))
422}
423
424// =============================================================================
425// Top-level macro input parsing (SqlForgeInput::parse)
426// =============================================================================
427
428impl Parse for SqlForgeInput {
429    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
430        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
431            let sql = parse_sql_template(input)?;
432            (None, ResultSpec::None, false, sql)
433        } else if input.peek(kw::scalar) {
434            input.parse::<kw::scalar>()?;
435            let model: Type = input.parse()?;
436            input.parse::<Token![,]>()?;
437            let sql = parse_sql_template(input)?;
438            (None, ResultSpec::Single(Box::new(model)), true, sql)
439        } else if input.peek(syn::token::Paren) {
440            let result_map_kind = detect_parenthesized_map_kind(input)?;
441            match result_map_kind {
442                Some(MapKind::Results) => {
443                    let result = ResultSpec::Group(parse_result_map(input)?);
444                    input.parse::<Token![,]>()?;
445                    let sql = parse_sql_template(input)?;
446                    (None, result, false, sql)
447                }
448                _ => {
449                    return Err(input.error(
450                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
451                    ));
452                }
453            }
454        } else {
455            let first_ty: Type = input.parse()?;
456            input.parse::<Token![,]>()?;
457
458            if input.peek(LitStr) {
459                let model = first_ty;
460                let sql = parse_sql_template(input)?;
461                (None, ResultSpec::Single(Box::new(model)), false, sql)
462            } else if input.peek(kw::scalar) {
463                input.parse::<kw::scalar>()?;
464                let model: Type = input.parse()?;
465                input.parse::<Token![,]>()?;
466                let sql = parse_sql_template(input)?;
467                (
468                    Some(first_ty),
469                    ResultSpec::Single(Box::new(model)),
470                    true,
471                    sql,
472                )
473            } else if input.peek(syn::token::Paren)
474                && matches!(
475                    detect_parenthesized_map_kind(input)?,
476                    Some(MapKind::Results)
477                )
478            {
479                let result = ResultSpec::Group(parse_result_map(input)?);
480                input.parse::<Token![,]>()?;
481                let sql = parse_sql_template(input)?;
482                (Some(first_ty), result, false, sql)
483            } else {
484                let db = Some(first_ty);
485                let model: Type = input.parse()?;
486                input.parse::<Token![,]>()?;
487                let sql = parse_sql_template(input)?;
488                (db, ResultSpec::Single(Box::new(model)), false, sql)
489            }
490        };
491
492        let mut batch = None;
493        let mut params = ParamsSource::None;
494        let mut sections = Vec::new();
495        let mut seen_params = false;
496        let mut seen_sections = false;
497
498        if input.parse::<Token![,]>().is_ok() {
499            while !input.is_empty() {
500                if input.peek(Token![..]) {
501                    if batch.is_some() {
502                        return Err(
503                            input.error("sql_forge!: only one batch source argument is allowed")
504                        );
505                    }
506                    input.parse::<Token![..]>()?;
507                    batch = Some(input.parse::<Expr>()?);
508                } else if input.peek(syn::token::Paren) {
509                    match detect_parenthesized_map_kind(input)? {
510                        Some(MapKind::Results) => {
511                            return Err(input.error(
512                                "sql_forge!: result maps are only allowed as the macro result argument",
513                            ));
514                        }
515                        Some(MapKind::Params) => {
516                            if seen_params {
517                                return Err(
518                                    input.error("sql_forge!: only one parameter source is allowed")
519                                );
520                            }
521                            params = ParamsSource::Map(parse_param_map(input)?);
522                            seen_params = true;
523                        }
524                        Some(MapKind::Sections) => {
525                            if seen_sections {
526                                return Err(
527                                    input.error("sql_forge!: duplicate section map argument")
528                                );
529                            }
530                            sections = parse_section_map(input)?;
531                            seen_sections = true;
532                        }
533                        None => {
534                            if seen_params {
535                                return Err(
536                                    input.error("sql_forge!: only one parameter source is allowed")
537                                );
538                            }
539                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
540                            seen_params = true;
541                        }
542                    }
543                } else {
544                    if seen_params {
545                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
546                    }
547                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
548                    seen_params = true;
549                }
550
551                if input.parse::<Token![,]>().is_ok() {
552                    continue;
553                }
554                break;
555            }
556        }
557
558        if !input.is_empty() {
559            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
560        }
561
562        Ok(Self {
563            db,
564            result,
565            force_scalar,
566            sql,
567            params,
568            sections,
569            batch,
570        })
571    }
572}
573
574// =============================================================================
575// Database type resolution
576// =============================================================================
577
578fn resolve_db_from_env() -> Result<Type, String> {
579    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
580        return syn::parse_str::<Type>(&val).map_err(|err| {
581            format!(
582                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
583                val, err
584            )
585        });
586    }
587
588    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
589        Ok(d) => d,
590        Err(_) => {
591            return Err(
592                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
593                 or configure [package.metadata.sql_forge] in Cargo.toml"
594                    .to_string(),
595            );
596        }
597    };
598    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
599
600    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
601        format!(
602            "sql_forge!: failed to read {}: {}",
603            manifest_path.display(),
604            err
605        )
606    })?;
607
608    let value: toml::Value = toml::from_str(&cargo_toml)
609        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
610
611    let db_str = value
612        .get("package")
613        .and_then(|v| v.get("metadata"))
614        .and_then(|v| v.get("sql_forge"))
615        .and_then(|v| v.get("db"))
616        .and_then(|v| v.as_str())
617        .ok_or({
618            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
619             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
620        })?;
621
622    syn::parse_str::<Type>(db_str).map_err(|err| {
623        format!(
624            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
625            db_str, err
626        )
627    })
628}
629
630fn uses_dollar_params(db: &Type) -> bool {
631    let Type::Path(type_path) = db else {
632        return false;
633    };
634    type_path
635        .path
636        .segments
637        .last()
638        .is_some_and(|s| s.ident == "Postgres")
639}
640
641fn is_builtin_scalar_type(ty: &Type) -> bool {
642    let Type::Path(type_path) = ty else {
643        return false;
644    };
645
646    if type_path.qself.is_some()
647        || type_path.path.leading_colon.is_some()
648        || type_path.path.segments.len() != 1
649    {
650        return false;
651    }
652
653    let ident = &type_path.path.segments[0].ident;
654    ident == "i8"
655        || ident == "i16"
656        || ident == "i32"
657        || ident == "i64"
658        || ident == "isize"
659        || ident == "u8"
660        || ident == "u16"
661        || ident == "u32"
662        || ident == "u64"
663        || ident == "usize"
664        || ident == "f32"
665        || ident == "f64"
666        || ident == "bool"
667        || ident == "String"
668}
669
670fn scalar_output_type(model: &Type) -> Option<&Type> {
671    if is_builtin_scalar_type(model) {
672        return Some(model);
673    }
674    None
675}
676
677fn push_text_segment(out: &mut Vec<Segment>, text: String) {
678    if text.is_empty() {
679        return;
680    }
681    match out.last_mut() {
682        Some(Segment::Text(existing)) => existing.push_str(&text),
683        _ => out.push(Segment::Text(text)),
684    }
685}
686
687fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
688    let mut out = Vec::new();
689    let mut text = String::new();
690    let mut chars = sql.chars().peekable();
691
692    while let Some(ch) = chars.next() {
693        if ch != '{' {
694            text.push(ch);
695            continue;
696        }
697
698        if chars.peek() == Some(&'(') {
699            push_text_segment(&mut out, std::mem::take(&mut text));
700
701            let mut paren_depth = 0u32;
702            let mut content = String::new();
703            let mut found_close = false;
704            for ch in chars.by_ref() {
705                if ch == '{' {
706                    return Err(
707                        "sql_forge!: nested braces not allowed inside batch section".to_string()
708                    );
709                }
710                if ch == '}' {
711                    if paren_depth != 0 {
712                        return Err(
713                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
714                                .to_string(),
715                        );
716                    }
717                    found_close = true;
718                    break;
719                }
720                if ch == '(' {
721                    paren_depth += 1;
722                } else if ch == ')' {
723                    if paren_depth == 0 {
724                        return Err(
725                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
726                                .to_string(),
727                        );
728                    }
729                    paren_depth -= 1;
730                }
731                content.push(ch);
732            }
733            if !found_close {
734                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
735            }
736            let parts = parse_text_parts(&content);
737            for part in &parts {
738                if let TextPart::Param { is_list: true, .. } = part {
739                    return Err(
740                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
741                         batch sections; use plain parameters (:name) instead"
742                            .to_string(),
743                    );
744                }
745            }
746            out.push(Segment::Batch { parts });
747            continue;
748        }
749
750        if chars.peek() != Some(&'#') {
751            text.push(ch);
752            continue;
753        }
754
755        chars.next();
756        push_text_segment(&mut out, std::mem::take(&mut text));
757
758        let mut name = String::new();
759        loop {
760            let Some(next) = chars.next() else {
761                return Err("sql_forge!: section placeholder without closing }".to_string());
762            };
763            if next == '}' {
764                break;
765            }
766            name.push(next);
767        }
768
769        if name.is_empty() {
770            return Err("sql_forge!: empty section placeholder name".to_string());
771        }
772
773        out.push(Segment::Section { name });
774    }
775
776    push_text_segment(&mut out, text);
777    Ok(out)
778}
779
780// =============================================================================
781// SQL template parsing: {#sections} and :param placeholders
782// =============================================================================
783
784fn is_ident_start(ch: char) -> bool {
785    ch == '_' || ch.is_ascii_alphabetic()
786}
787
788fn is_ident_continue(ch: char) -> bool {
789    is_ident_start(ch) || ch.is_ascii_digit()
790}
791
792fn sanitize_backticked_alias_ident(content: &str) -> String {
793    let mut split_at = content.len();
794    for (idx, ch) in content.char_indices() {
795        if ch == '!' || ch == '?' || ch == ':' {
796            split_at = idx;
797            break;
798        }
799    }
800
801    if split_at == content.len() {
802        return content.to_string();
803    }
804
805    let base = content[..split_at].trim_end();
806    if base.is_empty() {
807        content.to_string()
808    } else {
809        base.to_string()
810    }
811}
812
813fn sanitize_runtime_sql_text(text: &str) -> String {
814    let mut out = String::with_capacity(text.len());
815    let mut chars = text.chars().peekable();
816
817    while let Some(ch) = chars.next() {
818        if ch != '`' {
819            out.push(ch);
820            continue;
821        }
822
823        let mut content = String::new();
824        let mut closed = false;
825
826        for next in chars.by_ref() {
827            if next == '`' {
828                closed = true;
829                break;
830            }
831            content.push(next);
832        }
833
834        if closed {
835            out.push('`');
836            out.push_str(&sanitize_backticked_alias_ident(&content));
837            out.push('`');
838        } else {
839            out.push('`');
840            out.push_str(&content);
841            break;
842        }
843    }
844
845    out
846}
847
848fn parse_text_parts(text: &str) -> Vec<TextPart> {
849    let mut parts = Vec::new();
850    let mut last = 0usize;
851    let mut iter = text.char_indices().peekable();
852
853    while let Some((idx, ch)) = iter.next() {
854        if ch != ':' {
855            continue;
856        }
857
858        let Some(&(next_idx, next_ch)) = iter.peek() else {
859            continue;
860        };
861
862        if !is_ident_start(next_ch) {
863            continue;
864        }
865
866        if text[..idx].ends_with(':') {
867            continue;
868        }
869
870        if last < idx {
871            parts.push(TextPart::Lit(text[last..idx].to_string()));
872        }
873
874        iter.next();
875
876        let mut name = String::new();
877        name.push(next_ch);
878        let mut end = next_idx + next_ch.len_utf8();
879
880        while let Some(&(j, c)) = iter.peek() {
881            if is_ident_continue(c) {
882                name.push(c);
883                end = j + c.len_utf8();
884                iter.next();
885            } else {
886                break;
887            }
888        }
889
890        let mut is_list = false;
891        if text[end..].starts_with("[]") {
892            is_list = true;
893            end += 2;
894        }
895
896        parts.push(TextPart::Param { name, is_list });
897        last = end;
898    }
899
900    if last < text.len() {
901        parts.push(TextPart::Lit(text[last..].to_string()));
902    }
903
904    parts
905}
906
907fn render_validator_text(
908    text: &str,
909    use_dollar_params: bool,
910    param_offset: &mut usize,
911    list_count: usize,
912) -> (String, Vec<(String, bool)>) {
913    let mut out_sql = String::new();
914    let mut occurrences = Vec::new();
915
916    for part in parse_text_parts(text) {
917        match part {
918            TextPart::Lit(lit) => out_sql.push_str(&lit),
919            TextPart::Param { name, is_list } => {
920                if is_list && list_count > 1 {
921                    let slots: Vec<String> = if use_dollar_params {
922                        (0..list_count)
923                            .map(|i| format!("${}", *param_offset + i + 1))
924                            .collect()
925                    } else {
926                        (0..list_count).map(|_| "?".to_string()).collect()
927                    };
928                    if use_dollar_params {
929                        *param_offset += list_count;
930                    }
931                    out_sql.push_str(&slots.join(", "));
932                } else if use_dollar_params {
933                    *param_offset += 1;
934                    write!(out_sql, "${}", *param_offset).unwrap();
935                } else {
936                    out_sql.push('?');
937                }
938                occurrences.push((name, is_list));
939            }
940        }
941    }
942
943    (out_sql, occurrences)
944}
945
946fn strip_expr(expr: &Expr) -> &Expr {
947    match expr {
948        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
949        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
950        Expr::Block(ExprBlock { block, .. }) => {
951            if block.stmts.len() != 1 {
952                return expr;
953            }
954            match &block.stmts[0] {
955                Stmt::Expr(inner, None) => strip_expr(inner),
956                _ => expr,
957            }
958        }
959        _ => expr,
960    }
961}
962
963fn extract_lit_str(expr: &Expr) -> Option<String> {
964    match strip_expr(expr) {
965        Expr::Lit(ExprLit {
966            lit: Lit::Str(lit), ..
967        }) => Some(lit.value()),
968        _ => None,
969    }
970}
971
972// =============================================================================
973// Preprocessing: {>key} compile-time result flags
974// =============================================================================
975
976fn result_flag_ident(name: &str) -> syn::Ident {
977    format_ident!("__enhanced_result_flag_{}", name)
978}
979
980/// Replaces `{>key}` tokens inside braced groups with `__enhanced_result_flag_key`
981/// identifiers. This is a preprocessing step so that the rest of the parser sees
982/// plain identifiers instead of braced groups it does not understand.
983fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
984    fn walk(stream: TokenStream2) -> TokenStream2 {
985        let mut out = TokenStream2::new();
986        let iter = stream.into_iter().peekable();
987
988        for token in iter {
989            match token {
990                TokenTree::Group(group) => {
991                    if group.delimiter() == Delimiter::Brace {
992                        let mut inner = group.stream().into_iter();
993                        let first = inner.next();
994                        let second = inner.next();
995                        let third = inner.next();
996
997                        if let (
998                            Some(TokenTree::Punct(p)),
999                            Some(TokenTree::Ident(name_ident)),
1000                            None,
1001                        ) = (first, second, third)
1002                        {
1003                            if p.as_char() == '>' {
1004                                let ident = result_flag_ident(&name_ident.to_string());
1005                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1006                                continue;
1007                            }
1008                        }
1009                    }
1010
1011                    let new_inner = walk(group.stream());
1012                    let mut new_group = Group::new(group.delimiter(), new_inner);
1013                    new_group.set_span(group.span());
1014                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1015                }
1016                other => out.extend(std::iter::once(other)),
1017            }
1018        }
1019
1020        out
1021    }
1022
1023    walk(input)
1024}
1025
1026fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1027    keys.iter()
1028        .map(|key| {
1029            let ident = result_flag_ident(key);
1030            let enabled = Some(key.as_str()) == active_key;
1031            quote! { let #ident: bool = #enabled; }
1032        })
1033        .collect()
1034}
1035
1036fn transpose_section_case_matrix(
1037    case_matrix: Vec<Vec<SectionFragment>>,
1038    width: usize,
1039) -> Result<Vec<Vec<SectionFragment>>, String> {
1040    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1041
1042    for row in case_matrix {
1043        if row.len() != width {
1044            return Err(
1045                "sql_forge!: grouped sections must return one item per section".to_string(),
1046            );
1047        }
1048        for (section_idx, fragment) in row.into_iter().enumerate() {
1049            per_section[section_idx].push(fragment);
1050        }
1051    }
1052
1053    Ok(per_section)
1054}
1055
1056fn collect_section_case_matrix(
1057    value: SectionValue,
1058    width: usize,
1059    active_key: Option<&str>,
1060) -> Result<Vec<Vec<SectionFragment>>, String> {
1061    match value {
1062        SectionValue::Single(fragment) => {
1063            if width != 1 {
1064                return Err(
1065                    "sql_forge!: grouped sections must return one item per section".to_string(),
1066                );
1067            }
1068            Ok(vec![vec![fragment]])
1069        }
1070        SectionValue::Grouped(values) => {
1071            if values.len() != width {
1072                return Err(
1073                    "sql_forge!: grouped sections must return one item per section".to_string(),
1074                );
1075            }
1076
1077            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1078            let mut nmax = 1usize;
1079
1080            for value in values {
1081                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1082                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1083                for mut row in item_matrix {
1084                    let fragment = row.pop().ok_or_else(|| {
1085                        "sql_forge!: grouped sections must return one item per section".to_string()
1086                    })?;
1087                    if !row.is_empty() {
1088                        return Err(
1089                            "sql_forge!: grouped sections must return one item per section"
1090                                .to_string(),
1091                        );
1092                    }
1093                    item_variants.push(fragment);
1094                }
1095                if item_variants.is_empty() {
1096                    return Err("sql_forge!: section match must have at least one arm".to_string());
1097                }
1098                nmax = nmax.max(item_variants.len());
1099                variants_by_section.push(item_variants);
1100            }
1101
1102            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1103            for case_idx in 0..nmax {
1104                let mut row = Vec::<SectionFragment>::with_capacity(width);
1105                for variants in &variants_by_section {
1106                    row.push(variants[case_idx % variants.len()].clone());
1107                }
1108                case_matrix.push(row);
1109            }
1110
1111            Ok(case_matrix)
1112        }
1113        SectionValue::Match { expr, arms } => {
1114            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1115
1116            if let Some(key) = expr_result_flag_key(&expr) {
1117                let target = active_key == Some(key.as_str());
1118                for arm in arms {
1119                    if arm.guard.is_none() {
1120                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1121                            continue;
1122                        }
1123                    }
1124                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1125                    wrap_section_case_matrix_for_match_arm(
1126                        &mut arm_cases,
1127                        &expr,
1128                        &arm.pat,
1129                        arm.guard.as_ref(),
1130                    );
1131                    case_matrix.extend(arm_cases);
1132                }
1133            } else {
1134                for arm in arms {
1135                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1136                    wrap_section_case_matrix_for_match_arm(
1137                        &mut arm_cases,
1138                        &expr,
1139                        &arm.pat,
1140                        arm.guard.as_ref(),
1141                    );
1142                    case_matrix.extend(arm_cases);
1143                }
1144            }
1145
1146            if case_matrix.is_empty() {
1147                return Err("sql_forge!: section match must have at least one arm".to_string());
1148            }
1149
1150            Ok(case_matrix)
1151        }
1152    }
1153}
1154
1155/// This function rewrites a section-local validator expression so it stays inside the same
1156/// match arm scope that originally introduced it
1157fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
1158    let match_expr = match_expr.clone();
1159    let pat = pat.clone();
1160    let pattern_binds_values = match &pat {
1161        Pat::Ident(_) => true,
1162        Pat::Or(pat_or) => pat_or
1163            .cases
1164            .iter()
1165            .any(|case| matches!(case, Pat::Ident(_))),
1166        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
1167        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
1168        Pat::Slice(pat_slice) => pat_slice
1169            .elems
1170            .iter()
1171            .any(|elem| matches!(elem, Pat::Ident(_))),
1172        Pat::Struct(pat_struct) => pat_struct
1173            .fields
1174            .iter()
1175            .any(|field| matches!(*field.pat, Pat::Ident(_))),
1176        Pat::Tuple(pat_tuple) => pat_tuple
1177            .elems
1178            .iter()
1179            .any(|elem| matches!(elem, Pat::Ident(_))),
1180        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
1181            .elems
1182            .iter()
1183            .any(|elem| matches!(elem, Pat::Ident(_))),
1184        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
1185        _ => false,
1186    };
1187
1188    if pattern_binds_values {
1189        if let Some(guard) = guard.cloned() {
1190            parse_quote! {
1191                match &(#match_expr) {
1192                    #pat if #guard => { #expr },
1193                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1194                }
1195            }
1196        } else {
1197            parse_quote! {
1198                match &(#match_expr) {
1199                    #pat => { #expr },
1200                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1201                }
1202            }
1203        }
1204    } else if let Some(guard) = guard.cloned() {
1205        parse_quote! {
1206            match &(#match_expr) {
1207                #pat if #guard => { &(#expr) },
1208                _ => unreachable!("sql_forge!: validator arm mismatch"),
1209            }
1210        }
1211    } else {
1212        parse_quote! {
1213            match &(#match_expr) {
1214                #pat => { &(#expr) },
1215                _ => unreachable!("sql_forge!: validator arm mismatch"),
1216            }
1217        }
1218    }
1219}
1220
1221fn wrap_params_source_for_match_arm(
1222    params: &mut ParamsSource,
1223    match_expr: &Expr,
1224    pat: &Pat,
1225    guard: Option<&Expr>,
1226) {
1227    match params {
1228        ParamsSource::None => {}
1229        ParamsSource::Map(entries) => {
1230            for entry in entries {
1231                entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1232            }
1233        }
1234        ParamsSource::Struct(expr) => {
1235            **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1236        }
1237    }
1238}
1239
1240fn wrap_section_case_matrix_for_match_arm(
1241    case_matrix: &mut [Vec<SectionFragment>],
1242    match_expr: &Expr,
1243    pat: &Pat,
1244    guard: Option<&Expr>,
1245) {
1246    for row in case_matrix {
1247        for fragment in row {
1248            wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1249        }
1250    }
1251}
1252
1253// Returns Vec<Vec<SectionFragment>> indexed [section_idx][case_idx].
1254// =============================================================================
1255// Section variant collection
1256// =============================================================================
1257
1258/// Collects all possible `SectionFragment` values per section index. Each
1259/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1260/// listing every variant that section can take across all match arms.
1261/// Used for full validation (cycling strategy over all variants).
1262fn collect_section_variants(
1263    value: SectionValue,
1264    width: usize,
1265) -> Result<Vec<Vec<SectionFragment>>, String> {
1266    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1267}
1268
1269fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1270    match strip_expr(expr) {
1271        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1272            let name = path.path.segments[0].ident.to_string();
1273            name.strip_prefix("__enhanced_result_flag_")
1274                .map(|v| v.to_string())
1275        }
1276        _ => None,
1277    }
1278}
1279
1280fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1281    match pat {
1282        Pat::Lit(expr_lit) => match &expr_lit.lit {
1283            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1284            _ => None,
1285        },
1286        Pat::Wild(_) => Some(true),
1287        _ => None,
1288    }
1289}
1290
1291/// Like `collect_section_variants`, but filters `match` arms by the active
1292/// result key when the match expression is a `{>key}` flag. When building the
1293/// query for a specific key, only the matching arm (true/false) is included;
1294/// arms with guards or non-flag expressions include all variants as usual.
1295fn collect_section_variants_for_result(
1296    value: SectionValue,
1297    width: usize,
1298    active_key: Option<&str>,
1299) -> Result<Vec<Vec<SectionFragment>>, String> {
1300    transpose_section_case_matrix(
1301        collect_section_case_matrix(value, width, active_key)?,
1302        width,
1303    )
1304}
1305
1306// =============================================================================
1307// Parameter binding generation
1308// =============================================================================
1309
1310/// Generates `let` bindings for all parameters used in the SQL or section
1311/// fragments. Returns a map of param-name → local ident for later reference,
1312/// and a list of `let` statements.
1313fn build_param_bindings(
1314    params: &ParamsSource,
1315    used_param_names: &[String],
1316    prefix: &str,
1317    for_validator: bool,
1318    enforce_usage_check: bool,
1319) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1320    let mut declared_params = HashMap::<String, syn::Ident>::new();
1321    let mut bindings = Vec::<TokenStream2>::new();
1322
1323    match params {
1324        ParamsSource::None => {}
1325        ParamsSource::Map(entries) => {
1326            for entry in entries {
1327                let key = entry.name.to_string();
1328                if declared_params.contains_key(&key) {
1329                    return Err(syn::Error::new(
1330                        entry.name.span(),
1331                        "sql_forge!: duplicated parameter mapping",
1332                    )
1333                    .to_compile_error()
1334                    .into());
1335                }
1336                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1337                    return Err(syn::Error::new(
1338                        entry.name.span(),
1339                        format!(
1340                            "sql_forge!: parameter :{} is unused in the SQL template",
1341                            key,
1342                        ),
1343                    )
1344                    .to_compile_error()
1345                    .into());
1346                }
1347                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1348                let expr = &entry.expr;
1349                if for_validator {
1350                    bindings.push(quote! {
1351                        let #local_ident = &(#expr);
1352                    });
1353                } else {
1354                    bindings.push(quote! {
1355                        let #local_ident = #expr;
1356                    });
1357                }
1358                declared_params.insert(key, local_ident);
1359            }
1360        }
1361        ParamsSource::Struct(expr) => {
1362            let source_ident = format_ident!("__enhanced_source_{}", prefix);
1363            bindings.push(quote! {
1364                let #source_ident = &(#expr);
1365            });
1366            for name in used_param_names {
1367                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1368                let field_ident = format_ident!("{}", name);
1369                if for_validator {
1370                    bindings.push(quote! {
1371                        let #local_ident = &#source_ident.#field_ident;
1372                    });
1373                } else {
1374                    bindings.push(quote! {
1375                        let #local_ident = #source_ident.#field_ident;
1376                    });
1377                }
1378                declared_params.insert(name.to_string(), local_ident);
1379            }
1380        }
1381    }
1382
1383    Ok((declared_params, bindings))
1384}
1385
1386struct ValidatorRenderContext<'a> {
1387    local_params: &'a HashMap<String, syn::Ident>,
1388    top_level_params: &'a HashMap<String, syn::Ident>,
1389    allow_top_level_fallback: bool,
1390    use_dollar_params: bool,
1391    sql_span: Span,
1392    list_count: usize,
1393}
1394
1395/// Builds the placeholders SQL string and argument list for the compile-time
1396/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1397/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1398/// corresponding value expression is collected into the args list.
1399fn render_validator_args(
1400    sql: &str,
1401    param_offset: &mut usize,
1402    arg_index: &mut usize,
1403    context: &ValidatorRenderContext<'_>,
1404) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1405    let (rendered_sql, occurrences) = render_validator_text(
1406        sql,
1407        context.use_dollar_params,
1408        param_offset,
1409        context.list_count,
1410    );
1411
1412    let mut setup = Vec::<TokenStream2>::new();
1413    let mut args = Vec::<TokenStream2>::new();
1414
1415    for (name, is_list) in occurrences {
1416        let local_ident = if context.allow_top_level_fallback {
1417            context
1418                .local_params
1419                .get(&name)
1420                .or_else(|| context.top_level_params.get(&name))
1421        } else {
1422            context.local_params.get(&name)
1423        };
1424
1425        let Some(local_ident) = local_ident else {
1426            return Err(syn::Error::new(
1427                context.sql_span,
1428                format!("sql_forge!: parameter :{} has no mapping", name),
1429            )
1430            .to_compile_error()
1431            .into());
1432        };
1433
1434        if is_list {
1435            for _ in 0..context.list_count {
1436                let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1437                *arg_index += 1;
1438                if context.use_dollar_params {
1439                    setup.push(quote! {
1440                        let #value_ident = sql_forge::sql_forge_validator_value(
1441                            (#local_ident)
1442                                .as_slice()
1443                                .first()
1444                                .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1445                        );
1446                    });
1447                } else {
1448                    setup.push(quote! {
1449                        let #value_ident = (#local_ident)
1450                            .as_slice()
1451                            .first()
1452                            .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1453                    });
1454                }
1455                args.push(quote! { #value_ident });
1456            }
1457        } else {
1458            let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1459            *arg_index += 1;
1460            if context.use_dollar_params {
1461                setup.push(quote! {
1462                    let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1463                });
1464            } else {
1465                setup.push(quote! {
1466                    let #value_ident = #local_ident;
1467                });
1468            }
1469            args.push(quote! { #value_ident });
1470        }
1471    }
1472
1473    Ok((rendered_sql, setup, args))
1474}
1475
1476// =============================================================================
1477// Runtime code generation (QueryBuilder-based)
1478// =============================================================================
1479
1480/// Generates the `push()` / `push_bind()` calls for a single section fragment
1481/// at runtime using `sqlx::QueryBuilder`.
1482fn render_runtime_fragment(
1483    fragment: &SectionFragment,
1484    local_params: &HashMap<String, syn::Ident>,
1485) -> Result<TokenStream2, TokenStream> {
1486    let mut steps = Vec::<TokenStream2>::new();
1487
1488    for part in parse_text_parts(&fragment.sql) {
1489        match part {
1490            TextPart::Lit(lit) => {
1491                let lit_str = LitStr::new(&lit, fragment.span);
1492                steps.push(quote! { __builder.push(#lit_str); });
1493            }
1494            TextPart::Param { name, is_list } => {
1495                let Some(local_ident) = local_params.get(&name) else {
1496                    return Err(syn::Error::new(
1497                        fragment.span,
1498                        format!("sql_forge!: parameter :{} has no mapping", name),
1499                    )
1500                    .to_compile_error()
1501                    .into());
1502                };
1503
1504                if is_list {
1505                    steps.push(quote! {
1506                        let __enhanced_values = #local_ident;
1507                        let mut __separated = __builder.separated(", ");
1508                        for __value in __enhanced_values {
1509                            __separated.push_bind(__value);
1510                        }
1511                    });
1512                } else {
1513                    steps.push(quote! {
1514                        __builder.push_bind(#local_ident);
1515                    });
1516                }
1517            }
1518        }
1519    }
1520
1521    Ok(quote! { #( #steps )* })
1522}
1523
1524fn build_section_runtime_action(
1525    value: &SectionValue,
1526    section_idx: usize,
1527    prefix: &str,
1528) -> Result<TokenStream2, TokenStream> {
1529    match value {
1530        SectionValue::Single(fragment) => {
1531            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1532            let (local_params, bindings) =
1533                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1534            let body = render_runtime_fragment(fragment, &local_params)?;
1535            Ok(quote! {{ #( #bindings )* #body }})
1536        }
1537        SectionValue::Grouped(fragments) => build_section_runtime_action(
1538            &fragments[section_idx],
1539            0,
1540            &format!("{}_grouped_{}", prefix, section_idx),
1541        ),
1542        SectionValue::Match { expr, arms } => {
1543            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1544                .iter()
1545                .enumerate()
1546                .map(|(arm_idx, arm)| {
1547                    let pat = &arm.pat;
1548                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1549                    let body = build_section_runtime_action(
1550                        &arm.value,
1551                        section_idx,
1552                        &format!("{}_{}", prefix, arm_idx),
1553                    )?;
1554                    Ok::<TokenStream2, TokenStream>(quote! { #pat #guard_tokens => #body })
1555                })
1556                .collect();
1557            let arm_tokens = arm_tokens?;
1558            Ok(quote! {
1559                match #expr {
1560                    #( #arm_tokens ),*
1561                }
1562            })
1563        }
1564    }
1565}
1566
1567fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1568    let mut names = Vec::new();
1569    let mut seen = HashSet::<String>::new();
1570
1571    for segment in segments {
1572        match segment {
1573            Segment::Text(text) => {
1574                for name in collect_used_param_names_in_sql(text) {
1575                    if seen.insert(name.clone()) {
1576                        names.push(name);
1577                    }
1578                }
1579            }
1580            Segment::Batch { parts } => {
1581                for part in parts {
1582                    if let TextPart::Param { name, .. } = part {
1583                        if seen.insert(name.clone()) {
1584                            names.push(name.clone());
1585                        }
1586                    }
1587                }
1588            }
1589            _ => {}
1590        }
1591    }
1592
1593    names
1594}
1595
1596fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1597    let mut names = Vec::new();
1598    let mut seen = HashSet::<String>::new();
1599    for part in parse_text_parts(sql) {
1600        if let TextPart::Param { name, .. } = part {
1601            if seen.insert(name.to_string()) {
1602                names.push(name);
1603            }
1604        }
1605    }
1606    names
1607}
1608
1609/// Builds a parameterized SQL query with compile-time type-checking and a
1610/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1611///
1612/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1613/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1614///
1615/// # Syntax
1616///
1617/// ```text
1618/// sql_forge!(
1619///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1620///     [Model,]     // optional result spec
1621///     SQL,         // string literal
1622///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1623///     [(sections),]// optional: ( #name = ..., ... )
1624///     [..batch]    // optional: batch source expression used by {( ... )}
1625/// )
1626/// ```
1627///
1628/// `Model` has three forms:
1629/// - omitted: execute-only query; only `.execute(...)` is available
1630/// - `Type` or `scalar Type`: a single result query
1631/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1632///
1633/// The trailing parameter source, section map, and batch source are optional.
1634/// The batch source may appear alongside the others as a single `..expr` argument.
1635///
1636/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1637/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1638/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1639/// The env var takes priority over Cargo.toml metadata.
1640///
1641/// # Parameters
1642///
1643/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1644/// is replaced by a `push_bind` call; at compile time it becomes a
1645/// database-specific placeholder: `?` for MySQL and SQLite, and `$1`, `$2`, ...
1646/// for Postgres.
1647///
1648/// **Inline map** – bind individual expressions:
1649/// ```rust,ignore
1650/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1651/// ```
1652///
1653/// **Struct source** – field names are matched to `:name` placeholders automatically:
1654/// ```rust,ignore
1655/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1656/// ```
1657///
1658/// # Sections (`{#name}`)
1659///
1660/// Sections are runtime SQL slots; each section's variants are validated at
1661/// compile time via `query_as!` / `query_scalar!`, though not every combination
1662/// of variants across sections is checked. The section map is a second parenthesised
1663/// argument starting with `#`:
1664///
1665/// ```rust,ignore
1666/// sql_forge!(
1667///     User,
1668///     "SELECT * FROM users {#join_org}",
1669///     (
1670///         #join_org = match include_org {
1671///             true  => " JOIN organisations o ON o.id = users.org_id ",
1672///             false => "",
1673///         }
1674///     )
1675/// )
1676/// ```
1677///
1678/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1679///
1680/// ```rust,ignore
1681/// sql_forge!(
1682///     User,
1683///     "SELECT * FROM users {#filter}",
1684///     (
1685///         #filter = (
1686///             " WHERE id <= :max_id AND status = :status ",
1687///             ( :max_id = max_id, :status = "active" ),
1688///         )
1689///     )
1690/// )
1691/// ```
1692///
1693/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1694/// returning a tuple of the same width:
1695///
1696/// ```rust,ignore
1697/// sql_forge!(
1698///     User,
1699///     "SELECT * FROM users {#join_org} {#filter_org}",
1700///     (
1701///         #(join_org, filter_org) = match include_org {
1702///             true  => (
1703///                 " JOIN organisations o ON o.id = users.org_id ",
1704///                 (
1705///                     " AND o.active = :active ",
1706///                     ( :active = true ),
1707///                 ),
1708///             ),
1709///             false => ("", ""),
1710///         }
1711///     )
1712/// )
1713/// ```
1714///
1715/// Grouped section items may themselves use nested `match` expressions. Those
1716/// nested matches use smart cycling within the arm rather than a cartesian
1717/// product. For example, if one grouped arm returns a fixed first item plus two
1718/// nested binary matches for the second and third items, that arm contributes
1719/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1720/// `(1, 0)`, `(1, 1)` combinations.
1721///
1722/// # `IN (...)` with list parameters
1723///
1724/// Append `[]` to a placeholder (`:name[]`) to expand a `Vec` into multiple bound
1725/// values, separated by commas (typically inside `IN (...)`):
1726///
1727/// ```rust,ignore
1728/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1729/// ```
1730///
1731/// **Empty lists** are not special-cased: they render as `IN ()`, a database syntax error.
1732/// Guard against empty inputs explicitly, e.g. with a dynamic section:
1733///
1734/// ```rust,ignore
1735/// sql_forge!(
1736///     User,
1737///     "SELECT id, name FROM users WHERE {#filter}",
1738///     (
1739///         #filter = match ids.is_empty() {
1740///             true  => "1 = 0",
1741///             false => ("id IN (:ids[])", ( :ids = ids )),
1742///         }
1743///     )
1744/// )
1745/// ```
1746///
1747/// # Batch inserts (`{( ... )}`)
1748///
1749/// A batch section `{( ... )}` repeats its content for each item in an iterable
1750/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1751/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1752/// sections.
1753///
1754/// ## Struct batch
1755///
1756/// ```rust,ignore
1757/// struct BatchItem { name: String, price: i64 }
1758///
1759/// let items = vec![
1760///     BatchItem { name: "A".into(), price: 100 },
1761///     BatchItem { name: "B".into(), price: 200 },
1762/// ];
1763///
1764/// sql_forge!(
1765///     "INSERT INTO products (name, price, stock, category)
1766///      VALUES {(:name, :price, 10, 'Batch')}",
1767///     ..items
1768/// )
1769/// .execute(&pool)
1770/// .await?;
1771/// ```
1772///
1773/// For compile-time checking, the validator expands the batch to 3 placeholder copies
1774/// (`(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`), or to a single copy on SQLite.
1775/// At runtime the iterable drives the actual number of rows.
1776///
1777/// # Scalar output
1778///
1779/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) the macro uses
1780/// `query_scalar!` for validation and `build_query_scalar` for execution.
1781///
1782/// # Multiple results
1783///
1784/// A result map produces a `SqlForgeQueryGroup` with one query per key.
1785/// Each key can be a struct or a primitive (used as a scalar):
1786///
1787/// ```rust,ignore
1788/// sql_forge!(
1789///     (
1790///         >count = i64,
1791///         >items = Item,
1792///     ),
1793///     "SELECT {#fields} FROM items WHERE category_id = :cat",
1794///     ( :cat = category_id ),
1795///     (
1796///         #fields = match {>count} {           // {>key} is true when building
1797///             true  => "COUNT(*) AS total",    // the query for that model/result
1798///             false => "id, name, price",      // key and false otherwise
1799///         }
1800///     )
1801/// )
1802/// ```
1803///
1804/// The generated struct has one field per key (`group.count`, `group.items`),
1805/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
1806/// executor method (`fetch_one`, `fetch_all`, etc.).
1807///
1808/// # Execute-only (no model)
1809///
1810/// When the model type is omitted, the macro produces a value implementing
1811/// `SqlForgeQueryExecute`. Only `.execute(executor)`
1812/// is available and there is no return type to deserialize into. This is useful
1813/// for `INSERT`, `UPDATE`, `DELETE`, and other DML statements.
1814///
1815/// ```rust,ignore
1816/// sql_forge!(
1817///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
1818///     ( :id = 42i64 ),
1819/// )
1820/// .execute(&pool)
1821/// .await?;
1822/// ```
1823///
1824/// Sections and struct parameter sources work the same way as in model-backed queries:
1825///
1826/// ```rust,ignore
1827/// sql_forge!(
1828///     "UPDATE products SET price = :new_price {#filter}",
1829///     ( :new_price = 1999i64 ),
1830///     ( #filter = ("WHERE category = :cat", ( :cat = "Electronics" )) ),
1831/// )
1832/// .execute(&pool).await?;
1833/// ```
1834#[proc_macro]
1835#[allow(clippy::too_many_lines)]
1836pub fn sql_forge(input: TokenStream) -> TokenStream {
1837    // ---- Phase 1: Parse the macro input into structured data ----
1838    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1839    let SqlForgeInput {
1840        db,
1841        result,
1842        force_scalar,
1843        sql,
1844        params,
1845        sections,
1846        batch,
1847    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1848        Ok(v) => v,
1849        Err(err) => return err.to_compile_error().into(),
1850    };
1851
1852    // ---- Phase 2: Resolve database type (from macro arg or Cargo.toml) ----
1853    let db = match db {
1854        Some(db) => db,
1855        None => match resolve_db_from_env() {
1856            Ok(db) => db,
1857            Err(msg) => {
1858                return syn::Error::new(Span::call_site(), msg)
1859                    .to_compile_error()
1860                    .into();
1861            }
1862        },
1863    };
1864
1865    let use_dollar_params = uses_dollar_params(&db);
1866    let is_sqlite = if let syn::Type::Path(type_path) = &db {
1867        type_path
1868            .path
1869            .segments
1870            .last()
1871            .is_some_and(|s| s.ident == "Sqlite")
1872    } else {
1873        false
1874    };
1875    let list_count: usize = if is_sqlite { 1 } else { 3 };
1876
1877    // ---- Phase 3: Build result case definitions ----
1878    // Each result case is (optional_key, model_type, optional_scalar_type).
1879    // Scalar type is set for primitives and `scalar`-marked types.
1880    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1881        ResultSpec::None => {
1882            vec![(None, None, None)]
1883        }
1884        ResultSpec::Single(ref model) => {
1885            let model_ty = (**model).clone();
1886            let scalar = if force_scalar {
1887                Some(model_ty.clone())
1888            } else {
1889                scalar_output_type(model.as_ref()).cloned()
1890            };
1891            vec![(None, Some(model_ty), scalar)]
1892        }
1893        ResultSpec::Group(ref cases) => {
1894            if force_scalar {
1895                return syn::Error::new(
1896                    Span::call_site(),
1897                    "sql_forge!: scalar mode is not supported for grouped result maps",
1898                )
1899                .to_compile_error()
1900                .into();
1901            }
1902
1903            let mut out = Vec::new();
1904            let mut seen = HashSet::new();
1905            for case in cases {
1906                let key = case.name.to_string();
1907                if !seen.insert(key.clone()) {
1908                    return syn::Error::new(
1909                        case.name.span(),
1910                        "sql_forge!: duplicated key in result map",
1911                    )
1912                    .to_compile_error()
1913                    .into();
1914                }
1915
1916                let model = case.model.clone();
1917                let scalar = if case.force_scalar {
1918                    Some(model.clone())
1919                } else {
1920                    scalar_output_type(&case.model).cloned()
1921                };
1922                out.push((Some(key), Some(model), scalar));
1923            }
1924            out
1925        }
1926    };
1927    let group_result_keys: Vec<String> = result_cases
1928        .iter()
1929        .filter_map(|(key, _, _)| key.as_ref().cloned())
1930        .collect();
1931    let is_grouped_result = !group_result_keys.is_empty();
1932    let sql_span = sql.span();
1933
1934    // ---- Phase 4: Parse SQL into segments (text + {#section} slots) ----
1935    let segments = match sql.into_segments() {
1936        Ok(segments) => segments,
1937        Err(msg) => {
1938            return syn::Error::new(sql_span, msg).to_compile_error().into();
1939        }
1940    };
1941
1942    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
1943    match (&batch, has_batch_segment) {
1944        (None, true) => {
1945            return syn::Error::new(
1946                sql_span,
1947                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
1948                 was provided"
1949            )
1950            .to_compile_error()
1951            .into();
1952        }
1953        (Some(_), false) => {
1954            return syn::Error::new(
1955                sql_span,
1956                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
1957                 batch section",
1958            )
1959            .to_compile_error()
1960            .into();
1961        }
1962        _ => {}
1963    }
1964
1965    let used_param_names = collect_used_param_names(&segments);
1966
1967    // Batch-only params come from batch items, not the top-level params map.
1968    // They must be excluded from the usage check so that a param like :category
1969    // that appears only inside {( ... )} is flagged as unused when given in the
1970    // params map, as it would never be read from there at runtime.
1971    let batch_param_names: std::collections::HashSet<String> = segments
1972        .iter()
1973        .filter_map(|s| {
1974            if let Segment::Batch { parts } = s {
1975                Some(parts.iter().filter_map(|p| {
1976                    if let TextPart::Param { name, .. } = p {
1977                        Some(name.clone())
1978                    } else {
1979                        None
1980                    }
1981                }))
1982            } else {
1983                None
1984            }
1985        })
1986        .flatten()
1987        .collect();
1988    let top_level_used_names: Vec<String> = used_param_names
1989        .iter()
1990        .filter(|n| !batch_param_names.contains(*n))
1991        .cloned()
1992        .collect();
1993
1994    // ---- Phase 5: Build parameter bindings for the top-level params ----
1995    let (declared_params, validator_param_bindings) =
1996        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
1997            Ok(v) => v,
1998            Err(err) => return err,
1999        };
2000
2001    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2002
2003    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
2004    for assign in &sections {
2005        let SectionAssign { names, value } = assign;
2006
2007        // Build runtime actions first, while `value` is still available by reference.
2008        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2009        for (section_idx, name_ident) in names.iter().enumerate() {
2010            let name = name_ident.to_string();
2011            if runtime_section_actions.contains_key(&name) {
2012                return syn::Error::new(
2013                    name_ident.span(),
2014                    "sql_forge!: duplicated section mapping",
2015                )
2016                .to_compile_error()
2017                .into();
2018            }
2019            let action = match build_section_runtime_action(
2020                value,
2021                section_idx,
2022                &format!("section_{}", name),
2023            ) {
2024                Ok(action) => action,
2025                Err(err) => return err,
2026            };
2027            named_actions.push((name, action));
2028        }
2029
2030        // Consume `value` here so invalid grouped/nested section structures fail early.
2031        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2032            return syn::Error::new(names[0].span(), msg)
2033                .to_compile_error()
2034                .into();
2035        }
2036
2037        for (name, action) in named_actions {
2038            runtime_section_actions.insert(name, action);
2039        }
2040    }
2041
2042    let sql_section_names: std::collections::HashSet<&str> = segments
2043        .iter()
2044        .filter_map(|seg| {
2045            if let Segment::Section { name } = seg {
2046                Some(name.as_str())
2047            } else {
2048                None
2049            }
2050        })
2051        .collect();
2052    for name in runtime_section_actions.keys() {
2053        if !sql_section_names.contains(name.as_str()) {
2054            return syn::Error::new(
2055                sql_span,
2056                format!(
2057                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2058                    name, name,
2059                ),
2060            )
2061            .to_compile_error()
2062            .into();
2063        }
2064    }
2065
2066    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
2067    let mut generated_query_defs = Vec::<TokenStream2>::new();
2068    let mut generated_query_values = Vec::<TokenStream2>::new();
2069    let mut group_field_defs = Vec::<TokenStream2>::new();
2070    let mut group_method_defs = Vec::<TokenStream2>::new();
2071    let mut group_field_idents = Vec::<syn::Ident>::new();
2072    let mut group_field_tys = Vec::<TokenStream2>::new();
2073    let mut group_trait_impls = Vec::<TokenStream2>::new();
2074
2075    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2076
2077    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2078        let suffix = result_key.as_deref().unwrap_or("single");
2079        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2080        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2081
2082        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2083
2084        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2085        for assign in &sections {
2086            let SectionAssign { names, value } = assign;
2087            let variants_by_section = match collect_section_variants_for_result(
2088                value.clone(),
2089                names.len(),
2090                result_key.as_deref(),
2091            ) {
2092                Ok(v) => v,
2093                Err(msg) => {
2094                    return syn::Error::new(names[0].span(), msg)
2095                        .to_compile_error()
2096                        .into();
2097                }
2098            };
2099
2100            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2101                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2102            }
2103        }
2104
2105        let mut nmax = 1usize;
2106        for segment in &segments {
2107            if let Segment::Section { name } = segment {
2108                if let Some(variants) = section_variants_for_validation.get(name) {
2109                    if variants.is_empty() {
2110                        return syn::Error::new(
2111                            sql_span,
2112                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
2113                        )
2114                        .to_compile_error()
2115                        .into();
2116                    }
2117                    nmax = nmax.max(variants.len());
2118                } else {
2119                    return syn::Error::new(
2120                        sql_span,
2121                        format!("sql_forge!: section {{#{}}} has no mapping", name),
2122                    )
2123                    .to_compile_error()
2124                    .into();
2125                }
2126            }
2127        }
2128
2129        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2130        for case_idx in 0..nmax {
2131            let mut sql_case = String::new();
2132            let mut case_setup = Vec::<TokenStream2>::new();
2133            let mut case_args = Vec::<TokenStream2>::new();
2134            let mut param_offset = 0usize;
2135            let mut arg_index = 0usize;
2136            let empty_params = HashMap::<String, syn::Ident>::new();
2137            let root_validator_context = ValidatorRenderContext {
2138                local_params: &empty_params,
2139                top_level_params: &declared_params,
2140                allow_top_level_fallback: true,
2141                use_dollar_params,
2142                sql_span,
2143                list_count,
2144            };
2145
2146            for segment in &segments {
2147                match segment {
2148                    Segment::Text(text) => {
2149                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2150                            text,
2151                            &mut param_offset,
2152                            &mut arg_index,
2153                            &root_validator_context,
2154                        ) {
2155                            Ok(value) => value,
2156                            Err(err) => return err,
2157                        };
2158                        sql_case.push_str(&chunk_sql);
2159                        case_setup.extend(chunk_setup);
2160                        case_args.extend(chunk_args);
2161                    }
2162                    Segment::Section { name } => {
2163                        let Some(variants) = section_variants_for_validation.get(name) else {
2164                            return syn::Error::new(
2165                                sql_span,
2166                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2167                            )
2168                            .to_compile_error()
2169                            .into();
2170                        };
2171
2172                        let fragment = &variants[case_idx % variants.len()];
2173                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2174                        let (local_params, bindings) = match build_param_bindings(
2175                            &fragment.params,
2176                            &used_param_names,
2177                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2178                            true,
2179                            true,
2180                        ) {
2181                            Ok(value) => value,
2182                            Err(err) => return err,
2183                        };
2184                        let section_validator_context = ValidatorRenderContext {
2185                            local_params: &local_params,
2186                            top_level_params: &declared_params,
2187                            allow_top_level_fallback: false,
2188                            use_dollar_params,
2189                            sql_span: fragment.span,
2190                            list_count,
2191                        };
2192                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2193                            &fragment.sql,
2194                            &mut param_offset,
2195                            &mut arg_index,
2196                            &section_validator_context,
2197                        ) {
2198                            Ok(value) => value,
2199                            Err(err) => return err,
2200                        };
2201                        sql_case.push_str(&chunk_sql);
2202                        case_setup.extend(bindings);
2203                        case_setup.extend(chunk_setup);
2204                        case_args.extend(chunk_args);
2205                    }
2206                    Segment::Batch { parts } => {
2207                        let mut first = true;
2208                        for _ in 0..list_count {
2209                            let sep = if first { "" } else { ", " };
2210                            first = false;
2211                            sql_case.push_str(sep);
2212                            for tp in parts {
2213                                match tp {
2214                                    TextPart::Lit(lit) => sql_case.push_str(lit),
2215                                    TextPart::Param { name, .. } => {
2216                                        if let Some(batch_expr) = &batch {
2217                                            let field_ident = format_ident!("{}", name);
2218                                            if use_dollar_params {
2219                                                param_offset += 1;
2220                                                write!(sql_case, "${}", param_offset).unwrap();
2221                                            } else {
2222                                                sql_case.push('?');
2223                                            }
2224                                            case_args.push(quote! { #batch_expr[0].#field_ident });
2225                                        } else if use_dollar_params {
2226                                            param_offset += 1;
2227                                            write!(sql_case, "${}", param_offset).unwrap();
2228                                        } else {
2229                                            sql_case.push('?');
2230                                        }
2231                                    }
2232                                }
2233                            }
2234                        }
2235                    }
2236                }
2237            }
2238
2239            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2240        }
2241
2242        let mut validator_invocations = Vec::<TokenStream2>::new();
2243        for (sql_lit, case_setup, args) in &validator_cases {
2244            if model_opt.is_none() {
2245                if args.is_empty() {
2246                    validator_invocations.push(quote! {
2247                        {
2248                            #( #case_setup )*
2249                            let _ = sqlx::query_scalar!(
2250                                #sql_lit,
2251                            );
2252                        }
2253                    });
2254                } else {
2255                    validator_invocations.push(quote! {
2256                        {
2257                            #( #case_setup )*
2258                            let _ = sqlx::query_scalar!(
2259                                #sql_lit,
2260                                #( #args ),*
2261                            );
2262                        }
2263                    });
2264                }
2265            } else if let Some(scalar_ty) = scalar_model_ty {
2266                if args.is_empty() {
2267                    validator_invocations.push(quote! {
2268                        {
2269                            #( #case_setup )*
2270                            let _ = sqlx::query_scalar!(
2271                                #sql_lit,
2272                            );
2273                        }
2274                    });
2275                } else {
2276                    validator_invocations.push(quote! {
2277                        {
2278                            #( #case_setup )*
2279                            let _ = sqlx::query_scalar!(
2280                                #sql_lit,
2281                                #( #args ),*
2282                            );
2283                        }
2284                    });
2285                }
2286                let _ = scalar_ty;
2287            } else if args.is_empty() {
2288                validator_invocations.push(quote! {
2289                    {
2290                        #( #case_setup )*
2291                        let _ = sqlx::query_as!(
2292                            __EnhancedModel,
2293                            #sql_lit,
2294                        );
2295                    }
2296                });
2297            } else {
2298                validator_invocations.push(quote! {
2299                    {
2300                        #( #case_setup )*
2301                        let _ = sqlx::query_as!(
2302                            __EnhancedModel,
2303                            #sql_lit,
2304                            #( #args ),*
2305                        );
2306                    }
2307                });
2308            }
2309        }
2310
2311        let model_alias = if let Some(model) = model_opt {
2312            if scalar_model_ty.is_none() {
2313                quote! { type __EnhancedModel = #model; }
2314            } else {
2315                quote! {}
2316            }
2317        } else {
2318            quote! {}
2319        };
2320        grouped_validator_invocations.push(quote! {
2321            {
2322                #( #flag_bindings )*
2323                #model_alias
2324                #( #validator_invocations )*
2325            }
2326        });
2327
2328        let (runtime_declared_params, runtime_param_bindings) =
2329            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2330                Ok(v) => v,
2331                Err(err) => return err,
2332            };
2333
2334        let mut runtime_steps = Vec::<TokenStream2>::new();
2335        for (seg_idx, segment) in segments.iter().enumerate() {
2336            match segment {
2337                Segment::Text(text) => {
2338                    for part in parse_text_parts(text) {
2339                        match part {
2340                            TextPart::Lit(lit) => {
2341                                let lit = sanitize_runtime_sql_text(&lit);
2342                                let lit_str = LitStr::new(&lit, sql_span);
2343                                runtime_steps.push(quote! {
2344                                    __builder.push(#lit_str);
2345                                });
2346                            }
2347                            TextPart::Param { name, is_list } => {
2348                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2349                                    return syn::Error::new(
2350                                        sql_span,
2351                                        format!("sql_forge!: parameter :{} has no mapping", name),
2352                                    )
2353                                    .to_compile_error()
2354                                    .into();
2355                                };
2356
2357                                if is_list {
2358                                    runtime_steps.push(quote! {
2359                                        let __enhanced_values = #local_ident;
2360                                        let mut __separated = __builder.separated(", ");
2361                                        for __value in __enhanced_values {
2362                                            __separated.push_bind(__value);
2363                                        }
2364                                    });
2365                                } else {
2366                                    runtime_steps.push(quote! {
2367                                        __builder.push_bind(#local_ident);
2368                                    });
2369                                }
2370                            }
2371                        }
2372                    }
2373                }
2374                Segment::Section { name } => {
2375                    let Some(section_action) = runtime_section_actions.get(name) else {
2376                        let _ = seg_idx;
2377                        return syn::Error::new(
2378                            sql_span,
2379                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2380                        )
2381                        .to_compile_error()
2382                        .into();
2383                    };
2384                    runtime_steps.push(quote! {
2385                        #section_action
2386                    });
2387                }
2388                Segment::Batch { parts } => {
2389                    if let Some(batch_expr) = &batch {
2390                        let mut body = Vec::<TokenStream2>::new();
2391                        for part in parts {
2392                            match part {
2393                                TextPart::Lit(lit) => {
2394                                    let lit_str = LitStr::new(lit, sql_span);
2395                                    body.push(quote! {
2396                                        __builder.push(#lit_str);
2397                                    });
2398                                }
2399                                TextPart::Param { name, .. } => {
2400                                    let field_ident = format_ident!("{}", name);
2401                                    body.push(quote! {
2402                                        __builder.push_bind(__item.#field_ident);
2403                                    });
2404                                }
2405                            }
2406                        }
2407                        runtime_steps.push(quote! {
2408                            {
2409                                let mut __first = true;
2410                                for __item in #batch_expr {
2411                                    if !__first {
2412                                        __builder.push(", ");
2413                                    }
2414                                    __first = false;
2415                                    #( #body )*
2416                                }
2417                            }
2418                        });
2419                    }
2420                }
2421            }
2422        }
2423
2424        let exec_methods = if model_opt.is_none() {
2425            quote! {
2426                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2427                where
2428                    E: sqlx::Executor<'e, Database = #db>,
2429                {
2430                    self.inner.build().execute(executor).await
2431                }
2432            }
2433        } else if let Some(scalar_ty) = scalar_model_ty {
2434            quote! {
2435                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2436                where
2437                    E: sqlx::Executor<'e, Database = #db>,
2438                {
2439                    self.inner
2440                        .build_query_scalar::<#scalar_ty>()
2441                        .fetch_all(executor)
2442                        .await
2443                }
2444
2445                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2446                where
2447                    E: sqlx::Executor<'e, Database = #db>,
2448                {
2449                    self.inner
2450                        .build_query_scalar::<#scalar_ty>()
2451                        .fetch_one(executor)
2452                        .await
2453                }
2454
2455                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2456                where
2457                    E: sqlx::Executor<'e, Database = #db>,
2458                {
2459                    self.inner
2460                        .build_query_scalar::<#scalar_ty>()
2461                        .fetch_optional(executor)
2462                        .await
2463                }
2464
2465                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2466                where
2467                    E: sqlx::Executor<'e, Database = #db>,
2468                {
2469                    self.inner.build().execute(executor).await
2470                }
2471            }
2472        } else {
2473            let model = model_opt.as_ref().unwrap();
2474            quote! {
2475                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2476                where
2477                    E: sqlx::Executor<'e, Database = #db>,
2478                {
2479                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2480                }
2481
2482                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2483                where
2484                    E: sqlx::Executor<'e, Database = #db>,
2485                {
2486                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2487                }
2488
2489                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2490                where
2491                    E: sqlx::Executor<'e, Database = #db>,
2492                {
2493                    self.inner
2494                        .build_query_as::<#model>()
2495                        .fetch_optional(executor)
2496                        .await
2497                }
2498
2499                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2500                where
2501                    E: sqlx::Executor<'e, Database = #db>,
2502                {
2503                    self.inner.build().execute(executor).await
2504                }
2505            }
2506        };
2507
2508        let final_type: TokenStream2 = if let Some(model) = model_opt {
2509            if let Some(scalar_ty) = scalar_model_ty {
2510                quote! { #scalar_ty }
2511            } else {
2512                quote! { #model }
2513            }
2514        } else {
2515            quote! {}
2516        };
2517        let trait_impl = if model_opt.is_none() {
2518            quote! {
2519                impl<'args> sql_forge::SqlForgeQueryExecute
2520                    for #query_ident<'args>
2521                {
2522                    type Db = #db;
2523
2524                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2525                    where
2526                        Self: Sized + 'e,
2527                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2528                        #db: 'e,
2529                    {
2530                        #query_ident::execute(self, executor)
2531                    }
2532                }
2533            }
2534        } else {
2535            quote! {
2536                impl<'args> sql_forge::SqlForgeQuery<#final_type>
2537                    for #query_ident<'args>
2538                {
2539                    type Db = #db;
2540
2541                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2542                    where
2543                        Self: Sized + 'e,
2544                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2545                        #db: 'e,
2546                    {
2547                        #query_ident::fetch_all(self, executor)
2548                    }
2549
2550                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2551                    where
2552                        Self: Sized + 'e,
2553                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2554                        #db: 'e,
2555                    {
2556                        #query_ident::fetch_one(self, executor)
2557                    }
2558
2559                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2560                    where
2561                        Self: Sized + 'e,
2562                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2563                        #db: 'e,
2564                    {
2565                        #query_ident::fetch_optional(self, executor)
2566                    }
2567
2568                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2569                    where
2570                        Self: Sized + 'e,
2571                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2572                        #db: 'e,
2573                    {
2574                        #query_ident::execute(self, executor)
2575                    }
2576                }
2577            }
2578        };
2579
2580        generated_query_defs.push(quote! {
2581            struct #query_ident<'args> {
2582                inner: sqlx::QueryBuilder<'args, #db>,
2583            }
2584
2585            impl<'args> #query_ident<'args> {
2586                #exec_methods
2587            }
2588
2589            #trait_impl
2590        });
2591
2592        generated_query_values.push(quote! {
2593            #( #runtime_param_bindings )*
2594            #( #flag_bindings )*
2595            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2596            #( #runtime_steps )*
2597            let #query_value_ident = #query_ident { inner: __builder };
2598        });
2599
2600        if let Some(key) = result_key {
2601            let method_ident = format_ident!("{}", key);
2602            group_field_defs.push(quote! {
2603                #method_ident: #query_ident<'args>
2604            });
2605            group_field_tys.push(quote! { #query_ident<'args> });
2606            group_method_defs.push(quote! {
2607                pub fn #method_ident(self) -> #query_ident<'args> {
2608                    self.#method_ident
2609                }
2610            });
2611
2612            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2613            group_trait_impls.push(quote! {
2614                struct #key_ty_ident;
2615
2616                impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2617                    type Query = #query_ident<'args>;
2618
2619                    fn get(self, _: #key_ty_ident) -> Self::Query {
2620                        self.#method_ident
2621                    }
2622                }
2623            });
2624            group_field_idents.push(method_ident);
2625        }
2626    }
2627
2628    // ---- Phase 8: Emit the final token stream ----
2629    let validator_tokens = quote! {
2630        let _sql_forge_validator = || {
2631            #( #validator_param_bindings )*
2632            #( #grouped_validator_invocations )*
2633        };
2634    };
2635
2636    if !is_grouped_result {
2637        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2638        return quote! {
2639            {
2640                #validator_tokens
2641                #( #generated_query_defs )*
2642                #( #generated_query_values )*
2643                #single_query_value_ident
2644            }
2645        }
2646        .into();
2647    }
2648
2649    let group_field_inits: Vec<TokenStream2> = result_cases
2650        .iter()
2651        .filter_map(|(key, _, _)| key.as_ref())
2652        .map(|key| {
2653            let method_ident = format_ident!("{}", key);
2654            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2655            quote! { #method_ident: #query_value_ident }
2656        })
2657        .collect();
2658
2659    quote! {
2660        {
2661            #validator_tokens
2662
2663            #( #generated_query_defs )*
2664            #( #generated_query_values )*
2665
2666            struct __SqlForgeQueryGroup<'args> {
2667                #( #group_field_defs, )*
2668            }
2669
2670            impl<'args> __SqlForgeQueryGroup<'args> {
2671                #( #group_method_defs )*
2672
2673                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2674                    ( #( self.#group_field_idents ),* )
2675                }
2676            }
2677
2678            impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2679                type Db = #db;
2680            }
2681
2682            #( #group_trait_impls )*
2683
2684            __SqlForgeQueryGroup {
2685                #( #group_field_inits, )*
2686            }
2687        }
2688    }
2689    .into()
2690}
2691
2692/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2693/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2694///
2695/// ```rust,ignore
2696/// use sql_forge::db_type;
2697///
2698/// pub type AppDb = db_type!();
2699/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2700/// ```
2701///
2702/// Priority:
2703/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2704/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2705#[proc_macro]
2706pub fn db_type(input: TokenStream) -> TokenStream {
2707    if !input.is_empty() {
2708        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2709            .to_compile_error()
2710            .into();
2711    }
2712
2713    match resolve_db_from_env() {
2714        Ok(db) => quote! { #db }.into(),
2715        Err(msg) => syn::Error::new(Span::call_site(), msg)
2716            .to_compile_error()
2717            .into(),
2718    }
2719}
2720
2721/// Marks a single-value tuple struct as a transparent wrapper.
2722///
2723/// Equivalent to applying `#[derive(sqlx::Type)]` + `#[sqlx(transparent)]`
2724/// and additionally implements `SqlForgeValidatorValue` so that list
2725/// parameters (`:ids[]`) and other bindings validate correctly for
2726/// PostgreSQL databases (which require exact type matching in `query_as!`).
2727///
2728/// ```rust,ignore
2729/// #[derive(Debug, PartialEq, Eq)]
2730/// #[sql_forge_transparent]
2731/// struct UserId(pub i64);
2732/// ```
2733#[proc_macro_attribute]
2734pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2735    let input: ItemStruct = match syn::parse(item) {
2736        Ok(v) => v,
2737        Err(err) => return err.to_compile_error().into(),
2738    };
2739
2740    let struct_name = &input.ident;
2741    let inner_type = match &input.fields {
2742        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2743        _ => {
2744            return syn::Error::new(
2745                input.span(),
2746                "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2747            )
2748            .to_compile_error()
2749            .into();
2750        }
2751    };
2752
2753    let attrs = input.attrs;
2754    let generics = &input.generics;
2755    let vis = &input.vis;
2756    let struct_token = input.struct_token;
2757    let semi_token = input.semi_token;
2758    let fields = &input.fields;
2759
2760    let expanded = quote! {
2761        #( #attrs )*
2762        #[derive(sqlx::Type)]
2763        #[sqlx(transparent)]
2764        #vis #struct_token #struct_name #generics #fields #semi_token
2765
2766        impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2767            fn sql_forge_validator_value(&self) -> #inner_type {
2768                self.0.clone()
2769            }
2770        }
2771    };
2772
2773    expanded.into()
2774}