1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17 Stmt, Token, Type,
18};
19
/// Custom keywords recognized by the macro argument parser.
mod kw {
    // `scalar` prefixes a result model (`scalar T, "sql"` or `>name = scalar T`) and sets the
    // corresponding `force_scalar` flag.
    syn::custom_keyword!(scalar);
}
27
28#[derive(Clone)]
30struct ParamAssign {
31 name: Ident,
32 expr: Expr,
33}
34
35impl Parse for ParamAssign {
36 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37 input.parse::<Token![:]>()?;
38 let name: Ident = input.parse()?;
39 input.parse::<Token![=]>()?;
40 let expr: Expr = input.parse()?;
41 Ok(Self { name, expr })
42 }
43}
44
/// A section body: SQL text plus its optional section-local parameter source.
#[derive(Clone)]
struct SectionFragment {
    /// SQL text of the section (the value of the string literal).
    sql: String,
    /// Span of the expression the SQL text was taken from.
    span: Span,
    /// Section-local parameters; `ParamsSource::None` when the section is a bare literal.
    params: ParamsSource,
}
52
/// One `pat [if guard] => value` arm of a section `match`.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern (may use a leading `|` and or-patterns).
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Section value produced by this arm.
    value: SectionValue,
}
60
/// Right-hand side of a section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single fragment: `"sql"` or `("sql", params)`.
    Single(SectionFragment),
    /// One value per key of a grouped assignment, e.g. `#(a, b) = ("x", "y")`.
    Grouped(Vec<SectionValue>),
    /// `match expr { pat => value, ... }` selecting between section values.
    Match {
        /// The scrutinee expression.
        expr: Expr,
        /// The arms, in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
75#[derive(Clone)]
77struct SectionAssign {
78 names: Vec<Ident>,
79 value: SectionValue,
80}
81
82impl Parse for SectionAssign {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 input.parse::<Token![#]>()?;
85
86 let names = if input.peek(syn::token::Paren) {
88 let content;
89 syn::parenthesized!(content in input);
90 let mut out = Vec::new();
91 while !content.is_empty() {
92 out.push(content.parse::<Ident>()?);
93 if content.is_empty() {
94 break;
95 }
96 content.parse::<Token![,]>()?;
97 }
98 if out.is_empty() {
99 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100 }
101 out
102 } else {
103 vec![input.parse::<Ident>()?]
104 };
105
106 input.parse::<Token![=]>()?;
107 let value = parse_section_value(input, names.len())?;
108 Ok(Self { names, value })
109 }
110}
111
/// A fully parsed `sql_forge!` invocation.
struct SqlForgeInput {
    /// Explicit database type, when given before the result argument.
    db: Option<Type>,
    /// Result model argument.
    result: ResultSpec,
    /// Set when a single result model was prefixed with the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template argument.
    sql: SqlTemplate,
    /// Top-level parameter source: a `(:name = expr, ...)` map or another expression.
    params: ParamsSource,
    /// Entries of the `(#name = ..., ...)` section map; empty when none was given.
    sections: Vec<SectionAssign>,
    /// Batch source expression from a `..expr` argument.
    batch: Option<Expr>,
}
122
/// One `>name = Model` (or `>name = scalar Model`) entry of a result map.
#[derive(Clone)]
struct ResultAssign {
    /// Result key name.
    name: Ident,
    /// Model type for this result key.
    model: Type,
    /// Set when the model was prefixed with the `scalar` keyword.
    force_scalar: bool,
}
130
/// The result model argument of the macro.
#[derive(Clone)]
enum ResultSpec {
    /// No result argument: the invocation starts directly with the SQL template.
    None,
    /// A single model type.
    Single(Box<Type>),
    /// A result map `(>name = Model, ...)`.
    Group(Vec<ResultAssign>),
}
140
/// Where parameter values come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was given.
    None,
    /// An explicit `(:name = expr, ...)` map.
    Map(Vec<ParamAssign>),
    /// Any other expression; each parameter used by the SQL is read as a field of it.
    Struct(Box<Expr>),
}
149
150enum SqlTemplate {
152 Literal(LitStr),
153}
154
155impl SqlTemplate {
156 fn span(&self) -> Span {
157 match self {
158 Self::Literal(lit) => lit.span(),
159 }
160 }
161
162 fn into_segments(self) -> Result<Vec<Segment>, String> {
164 match self {
165 Self::Literal(lit) => parse_literal_segments(&lit.value()),
166 }
167 }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171 if input.peek(LitStr) {
172 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173 } else {
174 Err(input.error("sql_forge!: SQL template must be a string literal"))
175 }
176}
177
/// A piece of the parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Literal SQL text (adjacent text is merged by `push_text_segment`).
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, split into literal text and parameter parts.
    Batch { parts: Vec<TextPart> },
}
188
/// A piece of SQL text as split by `parse_text_parts`.
#[derive(Clone)]
enum TextPart {
    /// Literal text between parameters.
    Lit(String),
    /// A `:name` parameter; `is_list` is set for the `:name[]` form.
    Param { name: String, is_list: bool },
}
197
/// Kind of a parenthesized map argument, decided by the first token inside the parentheses.
enum MapKind {
    /// `(>name = Model, ...)`
    Results,
    /// `(:name = expr, ...)`
    Params,
    /// `(#name = value, ...)`
    Sections,
}
208
209fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
214 let fork = input.fork();
215 let content;
216 syn::parenthesized!(content in fork);
217
218 if content.is_empty() {
219 return Err(input.error("sql_forge!: map argument cannot be empty"));
220 }
221
222 if content.peek(Token![>]) {
223 Ok(Some(MapKind::Results))
224 } else if content.peek(Token![:]) {
225 Ok(Some(MapKind::Params))
226 } else if content.peek(Token![#]) {
227 Ok(Some(MapKind::Sections))
228 } else {
229 Ok(None)
230 }
231}
232
233impl Parse for ResultAssign {
234 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
235 input.parse::<Token![>]>()?;
236 let name: Ident = input.parse()?;
237 input.parse::<Token![=]>()?;
238 let (force_scalar, model) = if input.peek(kw::scalar) {
239 input.parse::<kw::scalar>()?;
240 (true, input.parse::<Type>()?)
241 } else {
242 (false, input.parse::<Type>()?)
243 };
244 Ok(Self {
245 name,
246 model,
247 force_scalar,
248 })
249 }
250}
251
252fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
253 let content;
254 syn::parenthesized!(content in input);
255
256 let mut results = Vec::new();
257 while !content.is_empty() {
258 results.push(content.parse::<ResultAssign>()?);
259 if content.is_empty() {
260 break;
261 }
262 content.parse::<Token![,]>()?;
263 }
264
265 if results.is_empty() {
266 return Err(input.error("sql_forge!: result map cannot be empty"));
267 }
268
269 Ok(results)
270}
271
272fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
273 let content;
274 syn::parenthesized!(content in input);
275
276 let mut params = Vec::new();
277 while !content.is_empty() {
278 params.push(content.parse::<ParamAssign>()?);
279 if content.is_empty() {
280 break;
281 }
282 content.parse::<Token![,]>()?;
283 }
284
285 Ok(params)
286}
287
288fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
289 let content;
290 syn::parenthesized!(content in input);
291
292 let mut sections = Vec::new();
293 while !content.is_empty() {
294 sections.push(content.parse::<SectionAssign>()?);
295 if content.is_empty() {
296 break;
297 }
298 content.parse::<Token![,]>()?;
299 }
300
301 Ok(sections)
302}
303
304fn parse_params_source_expr(
305 input: ParseStream<'_>,
306 allow_sections: bool,
307) -> syn::Result<ParamsSource> {
308 if input.peek(syn::token::Paren) {
309 match detect_parenthesized_map_kind(input)? {
310 Some(MapKind::Results) => Err(input
311 .error("sql_forge!: result maps are only allowed as the macro result argument")),
312 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
313 Some(MapKind::Sections) if allow_sections => {
314 Err(input.error("sql_forge!: section maps are not allowed here"))
315 }
316 Some(MapKind::Sections) => Err(input.error(
317 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
318 )),
319 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
320 }
321 } else {
322 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
323 }
324}
325
/// Parses one section fragment.
///
/// Accepted forms:
/// * a string literal, possibly wrapped in parentheses or a single-expression block
///   (unwrapped by `extract_lit_str`) — no section-local parameters;
/// * `("sql", params)` with an optional trailing comma, where `params` is parsed by
///   `parse_params_source_expr` (a `(:name = expr, ...)` map or another expression).
///
/// The tuple form is first checked on a fork; only if the whole group matches is it
/// re-parsed on `input`. Anything else falls through to the plain-expression path.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Speculative pass: does the group look like `(<string literal>, <params> [,])`?
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Note the `?`: errors from the params source surface even on the fork.
                let _ = parse_params_source_expr(&content, false)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // Shape confirmed: parse the same tokens again, this time consuming `input`.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content, false)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Plain form: the whole value must reduce to a string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
373
374fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
375 if input.peek(Token![match]) {
376 input.parse::<Token![match]>()?;
377 let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
378 let content;
379 syn::braced!(content in input);
380 let mut arms = Vec::new();
381 while !content.is_empty() {
382 let pat = content.call(Pat::parse_multi_with_leading_vert)?;
383 let guard = if content.peek(Token![if]) {
384 content.parse::<Token![if]>()?;
385 Some(content.parse::<Expr>()?)
386 } else {
387 None
388 };
389 content.parse::<Token![=>]>()?;
390 let value = parse_section_value(&content, width)?;
391 if content.peek(Token![,]) {
392 content.parse::<Token![,]>()?;
393 }
394 arms.push(SectionMatchArm { pat, guard, value });
395 }
396 return Ok(SectionValue::Match { expr, arms });
397 }
398
399 if width == 1 {
400 return Ok(SectionValue::Single(parse_section_fragment(input)?));
401 }
402
403 let content;
404 syn::parenthesized!(content in input);
405 let mut items = Vec::new();
406 while !content.is_empty() {
407 items.push(parse_section_value(&content, 1)?);
408 if content.is_empty() {
409 break;
410 }
411 content.parse::<Token![,]>()?;
412 }
413
414 if items.len() != width {
415 return Err(input.error(format!(
416 "sql_forge!: grouped section value must provide exactly {} items",
417 width,
418 )));
419 }
420
421 Ok(SectionValue::Grouped(items))
422}
423
impl Parse for SqlForgeInput {
    /// Parses the macro arguments.
    ///
    /// Leading forms, up to and including the SQL template:
    /// * `"sql"`
    /// * `scalar Model, "sql"`
    /// * `(>name = Model, ...), "sql"`
    /// * `Model, "sql"`
    /// * `Db, scalar Model, "sql"`
    /// * `Db, (>name = Model, ...), "sql"`
    /// * `Db, Model, "sql"`
    ///
    /// Optional trailing comma-separated arguments, in any order:
    /// * `..expr` — batch source (at most one);
    /// * `(:name = expr, ...)` or any other expression — parameter source (at most one);
    /// * `(#name = value, ...)` — section map (at most one).
    ///
    /// A trailing comma is accepted; any other leftover tokens are an error.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            // `"sql"` — no DB type and no result model.
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            // `scalar Model, "sql"`
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    // `(>name = Model, ...), "sql"`
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    // NOTE(review): a parenthesized tuple type such as `(A, B)` in first
                    // position also lands here and is rejected, although the message mentions
                    // model types; after an explicit DB type it is accepted by the final
                    // `else` branch below. Confirm whether that asymmetry is intended.
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // The first argument is a type: the model itself or the DB type.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) {
                // `Model, "sql"`
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                // `Db, scalar Model, "sql"`
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                // `Db, (>name = Model, ...), "sql"`
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `Db, Model, "sql"`
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Trailing arguments; a failed comma parse consumes nothing.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    // `..expr` batch source.
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // A parenthesized ordinary expression used as the parameter source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
573
574fn resolve_db_from_env() -> Result<Type, String> {
579 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
580 return syn::parse_str::<Type>(&val).map_err(|err| {
581 format!(
582 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
583 val, err
584 )
585 });
586 }
587
588 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
589 Ok(d) => d,
590 Err(_) => {
591 return Err(
592 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
593 or configure [package.metadata.sql_forge] in Cargo.toml"
594 .to_string(),
595 );
596 }
597 };
598 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
599
600 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
601 format!(
602 "sql_forge!: failed to read {}: {}",
603 manifest_path.display(),
604 err
605 )
606 })?;
607
608 let value: toml::Value = toml::from_str(&cargo_toml)
609 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
610
611 let db_str = value
612 .get("package")
613 .and_then(|v| v.get("metadata"))
614 .and_then(|v| v.get("sql_forge"))
615 .and_then(|v| v.get("db"))
616 .and_then(|v| v.as_str())
617 .ok_or({
618 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
619 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
620 })?;
621
622 syn::parse_str::<Type>(db_str).map_err(|err| {
623 format!(
624 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
625 db_str, err
626 )
627 })
628}
629
630fn uses_dollar_params(db: &Type) -> bool {
631 let Type::Path(type_path) = db else {
632 return false;
633 };
634 type_path
635 .path
636 .segments
637 .last()
638 .is_some_and(|s| s.ident == "Postgres")
639}
640
641fn is_builtin_scalar_type(ty: &Type) -> bool {
642 let Type::Path(type_path) = ty else {
643 return false;
644 };
645
646 if type_path.qself.is_some()
647 || type_path.path.leading_colon.is_some()
648 || type_path.path.segments.len() != 1
649 {
650 return false;
651 }
652
653 let ident = &type_path.path.segments[0].ident;
654 ident == "i8"
655 || ident == "i16"
656 || ident == "i32"
657 || ident == "i64"
658 || ident == "isize"
659 || ident == "u8"
660 || ident == "u16"
661 || ident == "u32"
662 || ident == "u64"
663 || ident == "usize"
664 || ident == "f32"
665 || ident == "f64"
666 || ident == "bool"
667 || ident == "String"
668}
669
670fn scalar_output_type(model: &Type) -> Option<&Type> {
671 if is_builtin_scalar_type(model) {
672 return Some(model);
673 }
674 None
675}
676
677fn push_text_segment(out: &mut Vec<Segment>, text: String) {
678 if text.is_empty() {
679 return;
680 }
681 match out.last_mut() {
682 Some(Segment::Text(existing)) => existing.push_str(&text),
683 _ => out.push(Segment::Text(text)),
684 }
685}
686
/// Splits a SQL template string into segments.
///
/// * `{( ... )}` becomes `Segment::Batch`. The captured content keeps the outer `( ... )`,
///   because the opening paren is only peeked, and it is split with `parse_text_parts`.
/// * `{#name}` becomes `Segment::Section { name }`. The name is every character up to the
///   next `}`, taken verbatim and not trimmed.
/// * Any other `{` and all remaining characters are literal text; adjacent text is merged.
///
/// # Errors
/// Returns a message for:
/// * a `{` nested inside a batch section, or unbalanced parentheses in one;
/// * a list parameter (`:name[]`) inside a batch section;
/// * an unterminated batch section or placeholder;
/// * an empty `{#}` placeholder.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        if chars.peek() == Some(&'(') {
            // Batch section `{( ... )}`.
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    // The closing brace only ends the batch when all parens are closed.
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` that starts neither a batch nor a `{#` placeholder is plain text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Section placeholder `{#name}`: consume the `#`, then read up to `}`.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
779
/// Returns `true` if `ch` may start a parameter name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
787
/// Returns `true` if `ch` may continue a parameter name: an ASCII letter, ASCII digit or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch == '_' || ch.is_ascii_alphanumeric()
}
791
/// Strips a type/nullability annotation from backticked alias content.
///
/// Everything from the first `!`, `?` or `:` onward is dropped, along with trailing
/// whitespace before it (e.g. `` `count: i64` `` content becomes `count`). If nothing would
/// remain, or no marker is present, the content is returned unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker_idx) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };

    let base = content[..marker_idx].trim_end();
    let chosen = if base.is_empty() { content } else { base };
    chosen.to_string()
}
812
813fn sanitize_runtime_sql_text(text: &str) -> String {
814 let mut out = String::with_capacity(text.len());
815 let mut chars = text.chars().peekable();
816
817 while let Some(ch) = chars.next() {
818 if ch != '`' {
819 out.push(ch);
820 continue;
821 }
822
823 let mut content = String::new();
824 let mut closed = false;
825
826 for next in chars.by_ref() {
827 if next == '`' {
828 closed = true;
829 break;
830 }
831 content.push(next);
832 }
833
834 if closed {
835 out.push('`');
836 out.push_str(&sanitize_backticked_alias_ident(&content));
837 out.push('`');
838 } else {
839 out.push('`');
840 out.push_str(&content);
841 break;
842 }
843 }
844
845 out
846}
847
/// Splits SQL text into literal runs and `:name` / `:name[]` parameters.
///
/// A `:` starts a parameter only when:
/// * the next character satisfies `is_ident_start`;
/// * the `:` is not itself preceded by `:`, so `a::b` (e.g. a Postgres cast) stays literal.
///
/// Quoting is not tracked: a `:name` inside a SQL string literal is still treated as a
/// parameter.
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset just past the last emitted parameter (start of the pending literal run).
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // Second colon of `::` — not a parameter.
        if text[..idx].ends_with(':') {
            continue;
        }

        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        // `[]` suffix marks a list parameter. The iterator will still yield `[` and `]`,
        // which are harmless because only `:` is acted on.
        let mut is_list = false;
        if text[end..].starts_with("[]") {
            is_list = true;
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
906
907fn render_validator_text(
908 text: &str,
909 use_dollar_params: bool,
910 param_offset: &mut usize,
911 list_count: usize,
912) -> (String, Vec<(String, bool)>) {
913 let mut out_sql = String::new();
914 let mut occurrences = Vec::new();
915
916 for part in parse_text_parts(text) {
917 match part {
918 TextPart::Lit(lit) => out_sql.push_str(&lit),
919 TextPart::Param { name, is_list } => {
920 if is_list && list_count > 1 {
921 let slots: Vec<String> = if use_dollar_params {
922 (0..list_count)
923 .map(|i| format!("${}", *param_offset + i + 1))
924 .collect()
925 } else {
926 (0..list_count).map(|_| "?".to_string()).collect()
927 };
928 if use_dollar_params {
929 *param_offset += list_count;
930 }
931 out_sql.push_str(&slots.join(", "));
932 } else if use_dollar_params {
933 *param_offset += 1;
934 write!(out_sql, "${}", *param_offset).unwrap();
935 } else {
936 out_sql.push('?');
937 }
938 occurrences.push((name, is_list));
939 }
940 }
941 }
942
943 (out_sql, occurrences)
944}
945
946fn strip_expr(expr: &Expr) -> &Expr {
947 match expr {
948 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
949 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
950 Expr::Block(ExprBlock { block, .. }) => {
951 if block.stmts.len() != 1 {
952 return expr;
953 }
954 match &block.stmts[0] {
955 Stmt::Expr(inner, None) => strip_expr(inner),
956 _ => expr,
957 }
958 }
959 _ => expr,
960 }
961}
962
963fn extract_lit_str(expr: &Expr) -> Option<String> {
964 match strip_expr(expr) {
965 Expr::Lit(ExprLit {
966 lit: Lit::Str(lit), ..
967 }) => Some(lit.value()),
968 _ => None,
969 }
970}
971
972fn result_flag_ident(name: &str) -> syn::Ident {
977 format_ident!("__enhanced_result_flag_{}", name)
978}
979
980fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
984 fn walk(stream: TokenStream2) -> TokenStream2 {
985 let mut out = TokenStream2::new();
986 let iter = stream.into_iter().peekable();
987
988 for token in iter {
989 match token {
990 TokenTree::Group(group) => {
991 if group.delimiter() == Delimiter::Brace {
992 let mut inner = group.stream().into_iter();
993 let first = inner.next();
994 let second = inner.next();
995 let third = inner.next();
996
997 if let (
998 Some(TokenTree::Punct(p)),
999 Some(TokenTree::Ident(name_ident)),
1000 None,
1001 ) = (first, second, third)
1002 {
1003 if p.as_char() == '>' {
1004 let ident = result_flag_ident(&name_ident.to_string());
1005 out.extend(std::iter::once(TokenTree::Ident(ident)));
1006 continue;
1007 }
1008 }
1009 }
1010
1011 let new_inner = walk(group.stream());
1012 let mut new_group = Group::new(group.delimiter(), new_inner);
1013 new_group.set_span(group.span());
1014 out.extend(std::iter::once(TokenTree::Group(new_group)));
1015 }
1016 other => out.extend(std::iter::once(other)),
1017 }
1018 }
1019
1020 out
1021 }
1022
1023 walk(input)
1024}
1025
1026fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1027 keys.iter()
1028 .map(|key| {
1029 let ident = result_flag_ident(key);
1030 let enabled = Some(key.as_str()) == active_key;
1031 quote! { let #ident: bool = #enabled; }
1032 })
1033 .collect()
1034}
1035
1036fn transpose_section_case_matrix(
1037 case_matrix: Vec<Vec<SectionFragment>>,
1038 width: usize,
1039) -> Result<Vec<Vec<SectionFragment>>, String> {
1040 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1041
1042 for row in case_matrix {
1043 if row.len() != width {
1044 return Err(
1045 "sql_forge!: grouped sections must return one item per section".to_string(),
1046 );
1047 }
1048 for (section_idx, fragment) in row.into_iter().enumerate() {
1049 per_section[section_idx].push(fragment);
1050 }
1051 }
1052
1053 Ok(per_section)
1054}
1055
1056fn collect_section_case_matrix(
1057 value: SectionValue,
1058 width: usize,
1059 active_key: Option<&str>,
1060) -> Result<Vec<Vec<SectionFragment>>, String> {
1061 match value {
1062 SectionValue::Single(fragment) => {
1063 if width != 1 {
1064 return Err(
1065 "sql_forge!: grouped sections must return one item per section".to_string(),
1066 );
1067 }
1068 Ok(vec![vec![fragment]])
1069 }
1070 SectionValue::Grouped(values) => {
1071 if values.len() != width {
1072 return Err(
1073 "sql_forge!: grouped sections must return one item per section".to_string(),
1074 );
1075 }
1076
1077 let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1078 let mut nmax = 1usize;
1079
1080 for value in values {
1081 let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1082 let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1083 for mut row in item_matrix {
1084 let fragment = row.pop().ok_or_else(|| {
1085 "sql_forge!: grouped sections must return one item per section".to_string()
1086 })?;
1087 if !row.is_empty() {
1088 return Err(
1089 "sql_forge!: grouped sections must return one item per section"
1090 .to_string(),
1091 );
1092 }
1093 item_variants.push(fragment);
1094 }
1095 if item_variants.is_empty() {
1096 return Err("sql_forge!: section match must have at least one arm".to_string());
1097 }
1098 nmax = nmax.max(item_variants.len());
1099 variants_by_section.push(item_variants);
1100 }
1101
1102 let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1103 for case_idx in 0..nmax {
1104 let mut row = Vec::<SectionFragment>::with_capacity(width);
1105 for variants in &variants_by_section {
1106 row.push(variants[case_idx % variants.len()].clone());
1107 }
1108 case_matrix.push(row);
1109 }
1110
1111 Ok(case_matrix)
1112 }
1113 SectionValue::Match { expr, arms } => {
1114 let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1115
1116 if let Some(key) = expr_result_flag_key(&expr) {
1117 let target = active_key == Some(key.as_str());
1118 for arm in arms {
1119 if arm.guard.is_none() {
1120 if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1121 continue;
1122 }
1123 }
1124 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1125 wrap_section_case_matrix_for_match_arm(
1126 &mut arm_cases,
1127 &expr,
1128 &arm.pat,
1129 arm.guard.as_ref(),
1130 );
1131 case_matrix.extend(arm_cases);
1132 }
1133 } else {
1134 for arm in arms {
1135 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1136 wrap_section_case_matrix_for_match_arm(
1137 &mut arm_cases,
1138 &expr,
1139 &arm.pat,
1140 arm.guard.as_ref(),
1141 );
1142 case_matrix.extend(arm_cases);
1143 }
1144 }
1145
1146 if case_matrix.is_empty() {
1147 return Err("sql_forge!: section match must have at least one arm".to_string());
1148 }
1149
1150 Ok(case_matrix)
1151 }
1152 }
1153}
1154
/// Wraps `expr` so that it is evaluated inside the given match arm.
///
/// The generated code is `match &(match_expr) { pat [if guard] => { ... }, _ => unreachable!() }`.
/// Any variables bound by `pat` are therefore in scope for `expr`, and reaching the
/// fallback arm panics with "validator arm mismatch".
///
/// Binding detection is shallow. It checks only whether the pattern is itself an identifier
/// pattern, or has one as a direct sub-pattern (or-cases, paren, reference, slice, struct
/// field, tuple, tuple-struct, typed pattern). An identifier-like constant such as `None`
/// also counts.
///
/// When bindings are detected, the arm body takes `&` of each name returned by
/// `pat_var_idents` (presumably to silence unused-variable warnings — confirm) and yields
/// `expr` itself. Otherwise the arm body yields `&(expr)`.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    // Shallow check: only identifier patterns at the top level or one level down count.
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1224
1225fn wrap_params_source_for_match_arm(
1226 params: &mut ParamsSource,
1227 match_expr: &Expr,
1228 pat: &Pat,
1229 guard: Option<&Expr>,
1230) {
1231 match params {
1232 ParamsSource::None => {}
1233 ParamsSource::Map(entries) => {
1234 for entry in entries {
1235 entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1236 }
1237 }
1238 ParamsSource::Struct(expr) => {
1239 **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1240 }
1241 }
1242}
1243
1244fn wrap_section_case_matrix_for_match_arm(
1245 case_matrix: &mut [Vec<SectionFragment>],
1246 match_expr: &Expr,
1247 pat: &Pat,
1248 guard: Option<&Expr>,
1249) {
1250 for row in case_matrix {
1251 for fragment in row {
1252 wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1253 }
1254 }
1255}
1256
1257fn collect_section_variants(
1267 value: SectionValue,
1268 width: usize,
1269) -> Result<Vec<Vec<SectionFragment>>, String> {
1270 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1271}
1272
1273fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1274 match strip_expr(expr) {
1275 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1276 let name = path.path.segments[0].ident.to_string();
1277 name.strip_prefix("__enhanced_result_flag_")
1278 .map(|v| v.to_string())
1279 }
1280 _ => None,
1281 }
1282}
1283
1284fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1285 match pat {
1286 Pat::Lit(expr_lit) => match &expr_lit.lit {
1287 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1288 _ => None,
1289 },
1290 Pat::Wild(_) => Some(true),
1291 _ => None,
1292 }
1293}
1294
1295fn collect_section_variants_for_result(
1300 value: SectionValue,
1301 width: usize,
1302 active_key: Option<&str>,
1303) -> Result<Vec<Vec<SectionFragment>>, String> {
1304 transpose_section_case_matrix(
1305 collect_section_case_matrix(value, width, active_key)?,
1306 width,
1307 )
1308}
1309
/// Generates `let` bindings that bring parameter values into scope as
/// `__enhanced_<prefix>_<name>` locals.
///
/// Binding shapes by source:
/// * `ParamsSource::Map` — one binding per entry. In validator mode (`for_validator`) it is
///   `&(expr)`, otherwise `expr` by value.
/// * `ParamsSource::Struct` — the source is first bound as `&(expr)`. Then, for each name in
///   `used_param_names`, the field `source.<name>` is bound as a reference (validator mode)
///   or read by value.
/// * `ParamsSource::None` — no bindings.
///
/// Returns the map from parameter name to its local identifier, plus the binding tokens.
///
/// # Errors
/// Returns a compile-error token stream for:
/// * a duplicated map key;
/// * a map key absent from `used_param_names`, when `enforce_usage_check` is set.
fn build_param_bindings(
    params: &ParamsSource,
    used_param_names: &[String],
    prefix: &str,
    for_validator: bool,
    enforce_usage_check: bool,
) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
    let mut declared_params = HashMap::<String, syn::Ident>::new();
    let mut bindings = Vec::<TokenStream2>::new();

    match params {
        ParamsSource::None => {}
        ParamsSource::Map(entries) => {
            for entry in entries {
                let key = entry.name.to_string();
                if declared_params.contains_key(&key) {
                    return Err(syn::Error::new(
                        entry.name.span(),
                        "sql_forge!: duplicated parameter mapping",
                    )
                    .to_compile_error()
                    .into());
                }
                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
                    return Err(syn::Error::new(
                        entry.name.span(),
                        format!(
                            "sql_forge!: parameter :{} is unused in the SQL template",
                            key,
                        ),
                    )
                    .to_compile_error()
                    .into());
                }
                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
                let expr = &entry.expr;
                if for_validator {
                    bindings.push(quote! {
                        let #local_ident = &(#expr);
                    });
                } else {
                    bindings.push(quote! {
                        let #local_ident = #expr;
                    });
                }
                declared_params.insert(key, local_ident);
            }
        }
        ParamsSource::Struct(expr) => {
            let source_ident = format_ident!("__enhanced_source_{}", prefix);
            bindings.push(quote! {
                let #source_ident = &(#expr);
            });
            for name in used_param_names {
                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
                let field_ident = format_ident!("{}", name);
                if for_validator {
                    bindings.push(quote! {
                        let #local_ident = &#source_ident.#field_ident;
                    });
                } else {
                    // NOTE(review): this reads the field by value through `&(expr)`, which
                    // only compiles for `Copy` field types. Confirm that callers of the
                    // non-validator path rely on that.
                    bindings.push(quote! {
                        let #local_ident = #source_ident.#field_ident;
                    });
                }
                declared_params.insert(name.to_string(), local_ident);
            }
        }
    }

    Ok((declared_params, bindings))
}
1389
/// Inputs shared by `render_validator_args` calls while rendering one
/// validator SQL case.
struct ValidatorRenderContext<'a> {
    /// Parameter bindings for the SQL chunk being rendered; always looked up first.
    local_params: &'a HashMap<String, syn::Ident>,
    /// Bindings produced for the macro's top-level parameter source.
    top_level_params: &'a HashMap<String, syn::Ident>,
    /// When true, names missing from `local_params` are looked up in
    /// `top_level_params`; when false only `local_params` is consulted.
    allow_top_level_fallback: bool,
    /// Forwarded to `render_validator_text`. When true, each validator argument
    /// is wrapped in `sql_forge::sql_forge_validator_value(..)`.
    use_dollar_params: bool,
    /// Span attached to "parameter has no mapping" errors.
    sql_span: Span,
    /// Number of validator arguments a list parameter expands to.
    list_count: usize,
}
1398
1399fn render_validator_args(
1404 sql: &str,
1405 param_offset: &mut usize,
1406 arg_index: &mut usize,
1407 context: &ValidatorRenderContext<'_>,
1408) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1409 let (rendered_sql, occurrences) = render_validator_text(
1410 sql,
1411 context.use_dollar_params,
1412 param_offset,
1413 context.list_count,
1414 );
1415
1416 let mut setup = Vec::<TokenStream2>::new();
1417 let mut args = Vec::<TokenStream2>::new();
1418
1419 for (name, is_list) in occurrences {
1420 let local_ident = if context.allow_top_level_fallback {
1421 context
1422 .local_params
1423 .get(&name)
1424 .or_else(|| context.top_level_params.get(&name))
1425 } else {
1426 context.local_params.get(&name)
1427 };
1428
1429 let Some(local_ident) = local_ident else {
1430 return Err(syn::Error::new(
1431 context.sql_span,
1432 format!("sql_forge!: parameter :{} has no mapping", name),
1433 )
1434 .to_compile_error()
1435 .into());
1436 };
1437
1438 if is_list {
1439 for _ in 0..context.list_count {
1440 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1441 *arg_index += 1;
1442 if context.use_dollar_params {
1443 setup.push(quote! {
1444 let #value_ident = sql_forge::sql_forge_validator_value(
1445 (#local_ident)
1446 .as_slice()
1447 .first()
1448 .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1449 );
1450 });
1451 } else {
1452 setup.push(quote! {
1453 let #value_ident = (#local_ident)
1454 .as_slice()
1455 .first()
1456 .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1457 });
1458 }
1459 args.push(quote! { #value_ident });
1460 }
1461 } else {
1462 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1463 *arg_index += 1;
1464 if context.use_dollar_params {
1465 setup.push(quote! {
1466 let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1467 });
1468 } else {
1469 setup.push(quote! {
1470 let #value_ident = #local_ident;
1471 });
1472 }
1473 args.push(quote! { #value_ident });
1474 }
1475 }
1476
1477 Ok((rendered_sql, setup, args))
1478}
1479
1480fn render_runtime_fragment(
1487 fragment: &SectionFragment,
1488 local_params: &HashMap<String, syn::Ident>,
1489) -> Result<TokenStream2, TokenStream> {
1490 let mut steps = Vec::<TokenStream2>::new();
1491
1492 for part in parse_text_parts(&fragment.sql) {
1493 match part {
1494 TextPart::Lit(lit) => {
1495 let lit_str = LitStr::new(&lit, fragment.span);
1496 steps.push(quote! { __builder.push(#lit_str); });
1497 }
1498 TextPart::Param { name, is_list } => {
1499 let Some(local_ident) = local_params.get(&name) else {
1500 return Err(syn::Error::new(
1501 fragment.span,
1502 format!("sql_forge!: parameter :{} has no mapping", name),
1503 )
1504 .to_compile_error()
1505 .into());
1506 };
1507
1508 if is_list {
1509 steps.push(quote! {
1510 let __enhanced_values = #local_ident;
1511 let mut __separated = __builder.separated(", ");
1512 for __value in __enhanced_values {
1513 __separated.push_bind(__value);
1514 }
1515 });
1516 } else {
1517 steps.push(quote! {
1518 __builder.push_bind(#local_ident);
1519 });
1520 }
1521 }
1522 }
1523 }
1524
1525 Ok(quote! { #( #steps )* })
1526}
1527
1528fn is_pat_binding(ident: &Ident) -> bool {
1529 let name = ident.to_string();
1530 !name.is_empty()
1531 && name
1532 .chars()
1533 .next()
1534 .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1535}
1536
1537fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1538 let mut names = Vec::new();
1539 fn walk(p: &Pat, names: &mut Vec<Ident>) {
1540 match p {
1541 Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1542 Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1543 Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1544 Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1545 Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1546 Pat::Paren(pp) => walk(&pp.pat, names),
1547 Pat::Reference(pr) => walk(&pr.pat, names),
1548 Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1549 Pat::Type(pt) => walk(&pt.pat, names),
1550 _ => {}
1551 }
1552 }
1553 walk(pat, &mut names);
1554 names
1555}
1556
1557fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1558 match value {
1559 SectionValue::Single(f) => {
1560 if collect_used_param_names_in_sql(&f.sql)
1561 .iter()
1562 .any(|n| n == name)
1563 {
1564 return true;
1565 }
1566 if let ParamsSource::Map(entries) = &f.params {
1567 for e in entries {
1568 let expr = &e.expr;
1569 let expr_str = quote! { #expr }.to_string();
1570 if expr_str.trim() == name {
1571 return true;
1572 }
1573 }
1574 }
1575 false
1576 }
1577 SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1578 SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1579 let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1580 .into_iter()
1581 .map(|i| i.to_string())
1582 .collect();
1583 if pat_vars.contains(name) {
1584 false
1585 } else {
1586 section_value_refers_to(&arm.value, name)
1587 }
1588 }),
1589 }
1590}
1591
1592fn build_section_runtime_action(
1593 value: &SectionValue,
1594 section_idx: usize,
1595 prefix: &str,
1596) -> Result<TokenStream2, TokenStream> {
1597 match value {
1598 SectionValue::Single(fragment) => {
1599 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1600 let (local_params, bindings) =
1601 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1602 let body = render_runtime_fragment(fragment, &local_params)?;
1603 Ok(quote! {{ #( #bindings )* #body }})
1604 }
1605 SectionValue::Grouped(fragments) => build_section_runtime_action(
1606 &fragments[section_idx],
1607 0,
1608 &format!("{}_grouped_{}", prefix, section_idx),
1609 ),
1610 SectionValue::Match { expr, arms } => {
1611 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1612 .iter()
1613 .enumerate()
1614 .map(|(arm_idx, arm)| {
1615 let pat = &arm.pat;
1616 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1617 let body = build_section_runtime_action(
1618 &arm.value,
1619 section_idx,
1620 &format!("{}_{}", prefix, arm_idx),
1621 )?;
1622 let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1623 .into_iter()
1624 .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1625 .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1626 .collect();
1627 Ok::<TokenStream2, TokenStream>(quote! {
1628 #pat #guard_tokens => {
1629 #( #noop_refs )*
1630 #body
1631 }
1632 })
1633 })
1634 .collect();
1635 let arm_tokens = arm_tokens?;
1636 Ok(quote! {
1637 match #expr {
1638 #( #arm_tokens ),*
1639 }
1640 })
1641 }
1642 }
1643}
1644
1645fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1646 let mut names = Vec::new();
1647 let mut seen = HashSet::<String>::new();
1648
1649 for segment in segments {
1650 match segment {
1651 Segment::Text(text) => {
1652 for name in collect_used_param_names_in_sql(text) {
1653 if seen.insert(name.clone()) {
1654 names.push(name);
1655 }
1656 }
1657 }
1658 Segment::Batch { parts } => {
1659 for part in parts {
1660 if let TextPart::Param { name, .. } = part {
1661 if seen.insert(name.clone()) {
1662 names.push(name.clone());
1663 }
1664 }
1665 }
1666 }
1667 _ => {}
1668 }
1669 }
1670
1671 names
1672}
1673
1674fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1675 let mut names = Vec::new();
1676 let mut seen = HashSet::<String>::new();
1677 for part in parse_text_parts(sql) {
1678 if let TextPart::Param { name, .. } = part {
1679 if seen.insert(name.to_string()) {
1680 names.push(name);
1681 }
1682 }
1683 }
1684 names
1685}
1686
1687#[proc_macro]
1913#[allow(clippy::too_many_lines)]
1914pub fn sql_forge(input: TokenStream) -> TokenStream {
1915 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1917 let SqlForgeInput {
1918 db,
1919 result,
1920 force_scalar,
1921 sql,
1922 params,
1923 sections,
1924 batch,
1925 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1926 Ok(v) => v,
1927 Err(err) => return err.to_compile_error().into(),
1928 };
1929
1930 let db = match db {
1932 Some(db) => db,
1933 None => match resolve_db_from_env() {
1934 Ok(db) => db,
1935 Err(msg) => {
1936 return syn::Error::new(Span::call_site(), msg)
1937 .to_compile_error()
1938 .into();
1939 }
1940 },
1941 };
1942
1943 let use_dollar_params = uses_dollar_params(&db);
1944 let is_sqlite = if let syn::Type::Path(type_path) = &db {
1945 type_path
1946 .path
1947 .segments
1948 .last()
1949 .is_some_and(|s| s.ident == "Sqlite")
1950 } else {
1951 false
1952 };
1953 let list_count: usize = if is_sqlite { 1 } else { 3 };
1954
1955 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1959 ResultSpec::None => {
1960 vec![(None, None, None)]
1961 }
1962 ResultSpec::Single(ref model) => {
1963 let model_ty = (**model).clone();
1964 let scalar = if force_scalar {
1965 Some(model_ty.clone())
1966 } else {
1967 scalar_output_type(model.as_ref()).cloned()
1968 };
1969 vec![(None, Some(model_ty), scalar)]
1970 }
1971 ResultSpec::Group(ref cases) => {
1972 if force_scalar {
1973 return syn::Error::new(
1974 Span::call_site(),
1975 "sql_forge!: scalar mode is not supported for grouped result maps",
1976 )
1977 .to_compile_error()
1978 .into();
1979 }
1980
1981 let mut out = Vec::new();
1982 let mut seen = HashSet::new();
1983 for case in cases {
1984 let key = case.name.to_string();
1985 if !seen.insert(key.clone()) {
1986 return syn::Error::new(
1987 case.name.span(),
1988 "sql_forge!: duplicated key in result map",
1989 )
1990 .to_compile_error()
1991 .into();
1992 }
1993
1994 let model = case.model.clone();
1995 let scalar = if case.force_scalar {
1996 Some(model.clone())
1997 } else {
1998 scalar_output_type(&case.model).cloned()
1999 };
2000 out.push((Some(key), Some(model), scalar));
2001 }
2002 out
2003 }
2004 };
2005 let group_result_keys: Vec<String> = result_cases
2006 .iter()
2007 .filter_map(|(key, _, _)| key.as_ref().cloned())
2008 .collect();
2009 let is_grouped_result = !group_result_keys.is_empty();
2010 let sql_span = sql.span();
2011
2012 let segments = match sql.into_segments() {
2014 Ok(segments) => segments,
2015 Err(msg) => {
2016 return syn::Error::new(sql_span, msg).to_compile_error().into();
2017 }
2018 };
2019
2020 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2021 match (&batch, has_batch_segment) {
2022 (None, true) => {
2023 return syn::Error::new(
2024 sql_span,
2025 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2026 was provided"
2027 )
2028 .to_compile_error()
2029 .into();
2030 }
2031 (Some(_), false) => {
2032 return syn::Error::new(
2033 sql_span,
2034 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2035 batch section",
2036 )
2037 .to_compile_error()
2038 .into();
2039 }
2040 _ => {}
2041 }
2042
2043 let used_param_names = collect_used_param_names(&segments);
2044
2045 let batch_param_names: std::collections::HashSet<String> = segments
2050 .iter()
2051 .filter_map(|s| {
2052 if let Segment::Batch { parts } = s {
2053 Some(parts.iter().filter_map(|p| {
2054 if let TextPart::Param { name, .. } = p {
2055 Some(name.clone())
2056 } else {
2057 None
2058 }
2059 }))
2060 } else {
2061 None
2062 }
2063 })
2064 .flatten()
2065 .collect();
2066 let top_level_used_names: Vec<String> = used_param_names
2067 .iter()
2068 .filter(|n| !batch_param_names.contains(*n))
2069 .cloned()
2070 .collect();
2071
2072 let (declared_params, validator_param_bindings) =
2074 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
2075 Ok(v) => v,
2076 Err(err) => return err,
2077 };
2078
2079 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2080
2081 for assign in §ions {
2083 let SectionAssign { names, value } = assign;
2084
2085 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2087 for (section_idx, name_ident) in names.iter().enumerate() {
2088 let name = name_ident.to_string();
2089 if runtime_section_actions.contains_key(&name) {
2090 return syn::Error::new(
2091 name_ident.span(),
2092 "sql_forge!: duplicated section mapping",
2093 )
2094 .to_compile_error()
2095 .into();
2096 }
2097 let action = match build_section_runtime_action(
2098 value,
2099 section_idx,
2100 &format!("section_{}", name),
2101 ) {
2102 Ok(action) => action,
2103 Err(err) => return err,
2104 };
2105 named_actions.push((name, action));
2106 }
2107
2108 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2110 return syn::Error::new(names[0].span(), msg)
2111 .to_compile_error()
2112 .into();
2113 }
2114
2115 for (name, action) in named_actions {
2116 runtime_section_actions.insert(name, action);
2117 }
2118 }
2119
2120 let sql_section_names: std::collections::HashSet<&str> = segments
2121 .iter()
2122 .filter_map(|seg| {
2123 if let Segment::Section { name } = seg {
2124 Some(name.as_str())
2125 } else {
2126 None
2127 }
2128 })
2129 .collect();
2130 for name in runtime_section_actions.keys() {
2131 if !sql_section_names.contains(name.as_str()) {
2132 return syn::Error::new(
2133 sql_span,
2134 format!(
2135 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2136 name, name,
2137 ),
2138 )
2139 .to_compile_error()
2140 .into();
2141 }
2142 }
2143
2144 let mut generated_query_defs = Vec::<TokenStream2>::new();
2146 let mut generated_query_values = Vec::<TokenStream2>::new();
2147 let mut group_field_defs = Vec::<TokenStream2>::new();
2148 let mut group_method_defs = Vec::<TokenStream2>::new();
2149 let mut group_field_idents = Vec::<syn::Ident>::new();
2150 let mut group_field_tys = Vec::<TokenStream2>::new();
2151 let mut group_trait_impls = Vec::<TokenStream2>::new();
2152
2153 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2154
2155 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2156 let suffix = result_key.as_deref().unwrap_or("single");
2157 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2158 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2159
2160 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2161
2162 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2163 for assign in §ions {
2164 let SectionAssign { names, value } = assign;
2165 let variants_by_section = match collect_section_variants_for_result(
2166 value.clone(),
2167 names.len(),
2168 result_key.as_deref(),
2169 ) {
2170 Ok(v) => v,
2171 Err(msg) => {
2172 return syn::Error::new(names[0].span(), msg)
2173 .to_compile_error()
2174 .into();
2175 }
2176 };
2177
2178 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2179 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2180 }
2181 }
2182
2183 let mut nmax = 1usize;
2184 for segment in &segments {
2185 if let Segment::Section { name } = segment {
2186 if let Some(variants) = section_variants_for_validation.get(name) {
2187 if variants.is_empty() {
2188 return syn::Error::new(
2189 sql_span,
2190 format!("sql_forge!: section {{#{}}} has no possible variants", name),
2191 )
2192 .to_compile_error()
2193 .into();
2194 }
2195 nmax = nmax.max(variants.len());
2196 } else {
2197 return syn::Error::new(
2198 sql_span,
2199 format!("sql_forge!: section {{#{}}} has no mapping", name),
2200 )
2201 .to_compile_error()
2202 .into();
2203 }
2204 }
2205 }
2206
2207 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2208 for case_idx in 0..nmax {
2209 let mut sql_case = String::new();
2210 let mut case_setup = Vec::<TokenStream2>::new();
2211 let mut case_args = Vec::<TokenStream2>::new();
2212 let mut param_offset = 0usize;
2213 let mut arg_index = 0usize;
2214 let empty_params = HashMap::<String, syn::Ident>::new();
2215 let root_validator_context = ValidatorRenderContext {
2216 local_params: &empty_params,
2217 top_level_params: &declared_params,
2218 allow_top_level_fallback: true,
2219 use_dollar_params,
2220 sql_span,
2221 list_count,
2222 };
2223
2224 for segment in &segments {
2225 match segment {
2226 Segment::Text(text) => {
2227 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2228 text,
2229 &mut param_offset,
2230 &mut arg_index,
2231 &root_validator_context,
2232 ) {
2233 Ok(value) => value,
2234 Err(err) => return err,
2235 };
2236 sql_case.push_str(&chunk_sql);
2237 case_setup.extend(chunk_setup);
2238 case_args.extend(chunk_args);
2239 }
2240 Segment::Section { name } => {
2241 let Some(variants) = section_variants_for_validation.get(name) else {
2242 return syn::Error::new(
2243 sql_span,
2244 format!("sql_forge!: section {{#{}}} has no mapping", name),
2245 )
2246 .to_compile_error()
2247 .into();
2248 };
2249
2250 let fragment = &variants[case_idx % variants.len()];
2251 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2252 let (local_params, bindings) = match build_param_bindings(
2253 &fragment.params,
2254 &used_param_names,
2255 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2256 true,
2257 true,
2258 ) {
2259 Ok(value) => value,
2260 Err(err) => return err,
2261 };
2262 let section_validator_context = ValidatorRenderContext {
2263 local_params: &local_params,
2264 top_level_params: &declared_params,
2265 allow_top_level_fallback: false,
2266 use_dollar_params,
2267 sql_span: fragment.span,
2268 list_count,
2269 };
2270 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2271 &fragment.sql,
2272 &mut param_offset,
2273 &mut arg_index,
2274 §ion_validator_context,
2275 ) {
2276 Ok(value) => value,
2277 Err(err) => return err,
2278 };
2279 sql_case.push_str(&chunk_sql);
2280 case_setup.extend(bindings);
2281 case_setup.extend(chunk_setup);
2282 case_args.extend(chunk_args);
2283 }
2284 Segment::Batch { parts } => {
2285 let mut first = true;
2286 for _ in 0..list_count {
2287 let sep = if first { "" } else { ", " };
2288 first = false;
2289 sql_case.push_str(sep);
2290 for tp in parts {
2291 match tp {
2292 TextPart::Lit(lit) => sql_case.push_str(lit),
2293 TextPart::Param { name, .. } => {
2294 if let Some(batch_expr) = &batch {
2295 let field_ident = format_ident!("{}", name);
2296 if use_dollar_params {
2297 param_offset += 1;
2298 write!(sql_case, "${}", param_offset).unwrap();
2299 } else {
2300 sql_case.push('?');
2301 }
2302 case_args.push(quote! { #batch_expr[0].#field_ident });
2303 } else if use_dollar_params {
2304 param_offset += 1;
2305 write!(sql_case, "${}", param_offset).unwrap();
2306 } else {
2307 sql_case.push('?');
2308 }
2309 }
2310 }
2311 }
2312 }
2313 }
2314 }
2315 }
2316
2317 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2318 }
2319
2320 let mut validator_invocations = Vec::<TokenStream2>::new();
2321 for (sql_lit, case_setup, args) in &validator_cases {
2322 if model_opt.is_none() {
2323 if args.is_empty() {
2324 validator_invocations.push(quote! {
2325 {
2326 #( #case_setup )*
2327 let _ = sqlx::query_scalar!(
2328 #sql_lit,
2329 );
2330 }
2331 });
2332 } else {
2333 validator_invocations.push(quote! {
2334 {
2335 #( #case_setup )*
2336 let _ = sqlx::query_scalar!(
2337 #sql_lit,
2338 #( #args ),*
2339 );
2340 }
2341 });
2342 }
2343 } else if let Some(scalar_ty) = scalar_model_ty {
2344 if args.is_empty() {
2345 validator_invocations.push(quote! {
2346 {
2347 #( #case_setup )*
2348 let _ = sqlx::query_scalar!(
2349 #sql_lit,
2350 );
2351 }
2352 });
2353 } else {
2354 validator_invocations.push(quote! {
2355 {
2356 #( #case_setup )*
2357 let _ = sqlx::query_scalar!(
2358 #sql_lit,
2359 #( #args ),*
2360 );
2361 }
2362 });
2363 }
2364 let _ = scalar_ty;
2365 } else if args.is_empty() {
2366 validator_invocations.push(quote! {
2367 {
2368 #( #case_setup )*
2369 let _ = sqlx::query_as!(
2370 __EnhancedModel,
2371 #sql_lit,
2372 );
2373 }
2374 });
2375 } else {
2376 validator_invocations.push(quote! {
2377 {
2378 #( #case_setup )*
2379 let _ = sqlx::query_as!(
2380 __EnhancedModel,
2381 #sql_lit,
2382 #( #args ),*
2383 );
2384 }
2385 });
2386 }
2387 }
2388
2389 let model_alias = if let Some(model) = model_opt {
2390 if scalar_model_ty.is_none() {
2391 quote! { type __EnhancedModel = #model; }
2392 } else {
2393 quote! {}
2394 }
2395 } else {
2396 quote! {}
2397 };
2398 grouped_validator_invocations.push(quote! {
2399 {
2400 #( #flag_bindings )*
2401 #model_alias
2402 #( #validator_invocations )*
2403 }
2404 });
2405
2406 let (runtime_declared_params, runtime_param_bindings) =
2407 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2408 Ok(v) => v,
2409 Err(err) => return err,
2410 };
2411
2412 let mut runtime_steps = Vec::<TokenStream2>::new();
2413 for (seg_idx, segment) in segments.iter().enumerate() {
2414 match segment {
2415 Segment::Text(text) => {
2416 for part in parse_text_parts(text) {
2417 match part {
2418 TextPart::Lit(lit) => {
2419 let lit = sanitize_runtime_sql_text(&lit);
2420 let lit_str = LitStr::new(&lit, sql_span);
2421 runtime_steps.push(quote! {
2422 __builder.push(#lit_str);
2423 });
2424 }
2425 TextPart::Param { name, is_list } => {
2426 let Some(local_ident) = runtime_declared_params.get(&name) else {
2427 return syn::Error::new(
2428 sql_span,
2429 format!("sql_forge!: parameter :{} has no mapping", name),
2430 )
2431 .to_compile_error()
2432 .into();
2433 };
2434
2435 if is_list {
2436 runtime_steps.push(quote! {
2437 let __enhanced_values = #local_ident;
2438 let mut __separated = __builder.separated(", ");
2439 for __value in __enhanced_values {
2440 __separated.push_bind(__value);
2441 }
2442 });
2443 } else {
2444 runtime_steps.push(quote! {
2445 __builder.push_bind(#local_ident);
2446 });
2447 }
2448 }
2449 }
2450 }
2451 }
2452 Segment::Section { name } => {
2453 let Some(section_action) = runtime_section_actions.get(name) else {
2454 let _ = seg_idx;
2455 return syn::Error::new(
2456 sql_span,
2457 format!("sql_forge!: section {{#{}}} has no mapping", name),
2458 )
2459 .to_compile_error()
2460 .into();
2461 };
2462 runtime_steps.push(quote! {
2463 #section_action
2464 });
2465 }
2466 Segment::Batch { parts } => {
2467 if let Some(batch_expr) = &batch {
2468 let mut body = Vec::<TokenStream2>::new();
2469 for part in parts {
2470 match part {
2471 TextPart::Lit(lit) => {
2472 let lit_str = LitStr::new(lit, sql_span);
2473 body.push(quote! {
2474 __builder.push(#lit_str);
2475 });
2476 }
2477 TextPart::Param { name, .. } => {
2478 let field_ident = format_ident!("{}", name);
2479 body.push(quote! {
2480 __builder.push_bind(__item.#field_ident);
2481 });
2482 }
2483 }
2484 }
2485 runtime_steps.push(quote! {
2486 {
2487 let mut __first = true;
2488 for __item in #batch_expr {
2489 if !__first {
2490 __builder.push(", ");
2491 }
2492 __first = false;
2493 #( #body )*
2494 }
2495 }
2496 });
2497 }
2498 }
2499 }
2500 }
2501
2502 let exec_methods = if model_opt.is_none() {
2503 quote! {
2504 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2505 where
2506 E: sqlx::Executor<'e, Database = #db>,
2507 {
2508 self.inner.build().execute(executor).await
2509 }
2510 }
2511 } else if let Some(scalar_ty) = scalar_model_ty {
2512 quote! {
2513 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2514 where
2515 E: sqlx::Executor<'e, Database = #db>,
2516 {
2517 self.inner
2518 .build_query_scalar::<#scalar_ty>()
2519 .fetch_all(executor)
2520 .await
2521 }
2522
2523 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2524 where
2525 E: sqlx::Executor<'e, Database = #db>,
2526 {
2527 self.inner
2528 .build_query_scalar::<#scalar_ty>()
2529 .fetch_one(executor)
2530 .await
2531 }
2532
2533 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2534 where
2535 E: sqlx::Executor<'e, Database = #db>,
2536 {
2537 self.inner
2538 .build_query_scalar::<#scalar_ty>()
2539 .fetch_optional(executor)
2540 .await
2541 }
2542
2543 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2544 where
2545 E: sqlx::Executor<'e, Database = #db>,
2546 {
2547 self.inner.build().execute(executor).await
2548 }
2549 }
2550 } else {
2551 let model = model_opt.as_ref().unwrap();
2552 quote! {
2553 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2554 where
2555 E: sqlx::Executor<'e, Database = #db>,
2556 {
2557 self.inner.build_query_as::<#model>().fetch_all(executor).await
2558 }
2559
2560 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2561 where
2562 E: sqlx::Executor<'e, Database = #db>,
2563 {
2564 self.inner.build_query_as::<#model>().fetch_one(executor).await
2565 }
2566
2567 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2568 where
2569 E: sqlx::Executor<'e, Database = #db>,
2570 {
2571 self.inner
2572 .build_query_as::<#model>()
2573 .fetch_optional(executor)
2574 .await
2575 }
2576
2577 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2578 where
2579 E: sqlx::Executor<'e, Database = #db>,
2580 {
2581 self.inner.build().execute(executor).await
2582 }
2583 }
2584 };
2585
2586 let final_type: TokenStream2 = if let Some(model) = model_opt {
2587 if let Some(scalar_ty) = scalar_model_ty {
2588 quote! { #scalar_ty }
2589 } else {
2590 quote! { #model }
2591 }
2592 } else {
2593 quote! {}
2594 };
2595 let trait_impl = if model_opt.is_none() {
2596 quote! {
2597 impl<'args> sql_forge::SqlForgeQueryExecute
2598 for #query_ident<'args>
2599 {
2600 type Db = #db;
2601
2602 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2603 where
2604 Self: Sized + 'e,
2605 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2606 #db: 'e,
2607 {
2608 #query_ident::execute(self, executor)
2609 }
2610 }
2611 }
2612 } else {
2613 quote! {
2614 impl<'args> sql_forge::SqlForgeQuery<#final_type>
2615 for #query_ident<'args>
2616 {
2617 type Db = #db;
2618
2619 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2620 where
2621 Self: Sized + 'e,
2622 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2623 #db: 'e,
2624 {
2625 #query_ident::fetch_all(self, executor)
2626 }
2627
2628 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2629 where
2630 Self: Sized + 'e,
2631 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2632 #db: 'e,
2633 {
2634 #query_ident::fetch_one(self, executor)
2635 }
2636
2637 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2638 where
2639 Self: Sized + 'e,
2640 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2641 #db: 'e,
2642 {
2643 #query_ident::fetch_optional(self, executor)
2644 }
2645
2646 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2647 where
2648 Self: Sized + 'e,
2649 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2650 #db: 'e,
2651 {
2652 #query_ident::execute(self, executor)
2653 }
2654 }
2655 }
2656 };
2657
2658 generated_query_defs.push(quote! {
2659 struct #query_ident<'args> {
2660 inner: sqlx::QueryBuilder<'args, #db>,
2661 }
2662
2663 impl<'args> #query_ident<'args> {
2664 #exec_methods
2665 }
2666
2667 #trait_impl
2668 });
2669
2670 generated_query_values.push(quote! {
2671 #( #runtime_param_bindings )*
2672 #( #flag_bindings )*
2673 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2674 #( #runtime_steps )*
2675 let #query_value_ident = #query_ident { inner: __builder };
2676 });
2677
2678 if let Some(key) = result_key {
2679 let method_ident = format_ident!("{}", key);
2680 group_field_defs.push(quote! {
2681 #method_ident: #query_ident<'args>
2682 });
2683 group_field_tys.push(quote! { #query_ident<'args> });
2684 group_method_defs.push(quote! {
2685 pub fn #method_ident(self) -> #query_ident<'args> {
2686 self.#method_ident
2687 }
2688 });
2689
2690 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2691 group_trait_impls.push(quote! {
2692 struct #key_ty_ident;
2693
2694 impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2695 type Query = #query_ident<'args>;
2696
2697 fn get(self, _: #key_ty_ident) -> Self::Query {
2698 self.#method_ident
2699 }
2700 }
2701 });
2702 group_field_idents.push(method_ident);
2703 }
2704 }
2705
2706 let validator_tokens = quote! {
2708 let _sql_forge_validator = || {
2709 #( #validator_param_bindings )*
2710 #( #grouped_validator_invocations )*
2711 };
2712 };
2713
2714 if !is_grouped_result {
2715 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2716 return quote! {
2717 {
2718 #validator_tokens
2719 #( #generated_query_defs )*
2720 #( #generated_query_values )*
2721 #single_query_value_ident
2722 }
2723 }
2724 .into();
2725 }
2726
2727 let group_field_inits: Vec<TokenStream2> = result_cases
2728 .iter()
2729 .filter_map(|(key, _, _)| key.as_ref())
2730 .map(|key| {
2731 let method_ident = format_ident!("{}", key);
2732 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2733 quote! { #method_ident: #query_value_ident }
2734 })
2735 .collect();
2736
2737 quote! {
2738 {
2739 #validator_tokens
2740
2741 #( #generated_query_defs )*
2742 #( #generated_query_values )*
2743
2744 struct __SqlForgeQueryGroup<'args> {
2745 #( #group_field_defs, )*
2746 }
2747
2748 impl<'args> __SqlForgeQueryGroup<'args> {
2749 #( #group_method_defs )*
2750
2751 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2752 ( #( self.#group_field_idents ),* )
2753 }
2754 }
2755
2756 impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2757 type Db = #db;
2758 }
2759
2760 #( #group_trait_impls )*
2761
2762 __SqlForgeQueryGroup {
2763 #( #group_field_inits, )*
2764 }
2765 }
2766 }
2767 .into()
2768}
2769
2770#[proc_macro]
2784pub fn db_type(input: TokenStream) -> TokenStream {
2785 if !input.is_empty() {
2786 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2787 .to_compile_error()
2788 .into();
2789 }
2790
2791 match resolve_db_from_env() {
2792 Ok(db) => quote! { #db }.into(),
2793 Err(msg) => syn::Error::new(Span::call_site(), msg)
2794 .to_compile_error()
2795 .into(),
2796 }
2797}
2798
2799#[proc_macro_attribute]
2814pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2815 let input: ItemStruct = match syn::parse(item) {
2816 Ok(v) => v,
2817 Err(err) => return err.to_compile_error().into(),
2818 };
2819
2820 let struct_name = &input.ident;
2821 let inner_type = match &input.fields {
2822 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2823 _ => {
2824 return syn::Error::new(
2825 input.span(),
2826 "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2827 )
2828 .to_compile_error()
2829 .into();
2830 }
2831 };
2832
2833 let attrs = input.attrs;
2834 let generics = &input.generics;
2835 let vis = &input.vis;
2836 let struct_token = input.struct_token;
2837 let semi_token = input.semi_token;
2838 let fields = &input.fields;
2839
2840 let expanded = quote! {
2841 #( #attrs )*
2842 #[derive(sqlx::Type)]
2843 #[sqlx(transparent)]
2844 #vis #struct_token #struct_name #generics #fields #semi_token
2845
2846 impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2847 fn sql_forge_validator_value(&self) -> #inner_type {
2848 self.0.clone()
2849 }
2850 }
2851 };
2852
2853 expanded.into()
2854}