1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17 Stmt, Token, Type,
18};
19
/// Custom keywords recognised by the `sql_forge!` argument parser.
mod kw {
    // `scalar` may prefix a result model type (`scalar Model` or
    // `>name = scalar Model`); the parser records it as a `force_scalar` flag.
    syn::custom_keyword!(scalar);
}
27
/// A parameter binding written as `:name = expr` inside a parameter map.
#[derive(Clone)]
struct ParamAssign {
    /// Parameter name, matched against `:name` references in the SQL text.
    name: Ident,
    /// Rust expression supplying the parameter value.
    expr: Expr,
}
34
35impl Parse for ParamAssign {
36 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37 input.parse::<Token![:]>()?;
38 let name: Ident = input.parse()?;
39 input.parse::<Token![=]>()?;
40 let expr: Expr = input.parse()?;
41 Ok(Self { name, expr })
42 }
43}
44
/// One concrete SQL fragment that can be substituted for a `{#name}` placeholder.
#[derive(Clone)]
struct SectionFragment {
    /// Value of the fragment's string literal.
    sql: String,
    /// Span of the expression the literal was taken from (presumably used for
    /// diagnostics — confirm at the use sites).
    span: Span,
    /// Parameters local to this fragment; `ParamsSource::None` for a bare literal.
    params: ParamsSource,
}
52
/// One arm of a section-level `match`: `pat [if guard] => value`.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern (a leading `|` is accepted by the parser).
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Section value produced by this arm; same width as the enclosing assignment.
    value: SectionValue,
}
60
/// Right-hand side of a `#name = ...` / `#(a, b) = ...` section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single fragment: `"sql"` or `("sql", params)`.
    Single(SectionFragment),
    /// One value per key of a grouped assignment, e.g. `#(a, b) = ("x", "y")`.
    Grouped(Vec<SectionValue>),
    /// `match expr { pat => value, ... }`; every arm value has the width of the
    /// assignment.
    Match {
        /// Scrutinee expression.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
/// A `#name = value` or `#(a, b, ...) = value` entry of the section map.
#[derive(Clone)]
struct SectionAssign {
    /// Section keys; more than one only for the grouped form.
    names: Vec<Ident>,
    /// Value whose width equals `names.len()`.
    value: SectionValue,
}
81
82impl Parse for SectionAssign {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 input.parse::<Token![#]>()?;
85
86 let names = if input.peek(syn::token::Paren) {
88 let content;
89 syn::parenthesized!(content in input);
90 let mut out = Vec::new();
91 while !content.is_empty() {
92 out.push(content.parse::<Ident>()?);
93 if content.is_empty() {
94 break;
95 }
96 content.parse::<Token![,]>()?;
97 }
98 if out.is_empty() {
99 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100 }
101 out
102 } else {
103 vec![input.parse::<Ident>()?]
104 };
105
106 input.parse::<Token![=]>()?;
107 let value = parse_section_value(input, names.len())?;
108 Ok(Self { names, value })
109 }
110}
111
/// Fully parsed arguments of one `sql_forge!` invocation.
struct SqlForgeInput {
    /// Explicit database type given as the first argument, if any.
    db: Option<Type>,
    /// Result model(s) the query decodes into.
    result: ResultSpec,
    /// Set when a single result model was written as `scalar Model`.
    force_scalar: bool,
    /// The SQL template.
    sql: SqlTemplate,
    /// Top-level parameter source: a `(:a = x, ...)` map or another expression.
    params: ParamsSource,
    /// Entries of the `(#name = ..., ...)` section map.
    sections: Vec<SectionAssign>,
    /// Batch source expression given as `..expr`.
    batch: Option<Expr>,
}
122
/// One entry of a result map: `>name = Model` or `>name = scalar Model`.
#[derive(Clone)]
struct ResultAssign {
    /// Result key (presumably the `name` referenced by `{>name}` flag
    /// placeholders — confirm at the use sites).
    name: Ident,
    /// Model type for this result.
    model: Type,
    /// Set when the model was prefixed with `scalar`.
    force_scalar: bool,
}
130
/// What the query result decodes into.
#[derive(Clone)]
enum ResultSpec {
    /// No result model was given.
    None,
    /// A single model type; `scalar` is tracked separately in
    /// `SqlForgeInput::force_scalar`.
    Single(Box<Type>),
    /// A keyed result map `(>a = A, >b = B, ...)`.
    Group(Vec<ResultAssign>),
}
140
/// Where bound parameter values come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameters supplied.
    None,
    /// An explicit `(:name = expr, ...)` map.
    Map(Vec<ParamAssign>),
    /// Any other expression (presumably a struct whose fields supply the
    /// parameters — confirm against `build_param_bindings`).
    Struct(Box<Expr>),
}
149
/// The SQL template argument. Only string literals are accepted
/// (see `parse_sql_template`).
enum SqlTemplate {
    /// A string-literal template.
    Literal(LitStr),
}
154
155impl SqlTemplate {
156 fn span(&self) -> Span {
157 match self {
158 Self::Literal(lit) => lit.span(),
159 }
160 }
161
162 fn into_segments(self) -> Result<Vec<Segment>, String> {
164 match self {
165 Self::Literal(lit) => parse_literal_segments(&lit.value()),
166 }
167 }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171 if input.peek(LitStr) {
172 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173 } else {
174 Err(input.error("sql_forge!: SQL template must be a string literal"))
175 }
176}
177
/// A piece of the parsed SQL template (see `parse_literal_segments`).
#[derive(Clone)]
enum Segment {
    /// Raw SQL text; `:name` references are not split out at this stage.
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, already split into text/parameter parts.
    Batch { parts: Vec<TextPart> },
}
188
/// A lexical piece of SQL text (see `parse_text_parts`).
#[derive(Clone)]
enum TextPart {
    /// Literal SQL copied through unchanged.
    Lit(String),
    /// A `:name` parameter reference; `is_list` is set for `:name[]`.
    Param { name: String, is_list: bool },
}
197
/// Kind of a parenthesized map argument, identified by the first token inside
/// the parentheses (see `detect_parenthesized_map_kind`).
enum MapKind {
    /// `(>name = Model, ...)` — result map.
    Results,
    /// `(:name = expr, ...)` — parameter map.
    Params,
    /// `(#name = value, ...)` — section map.
    Sections,
}
208
209fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
215 let fork = input.fork();
216 let content;
217 syn::parenthesized!(content in fork);
218
219 if content.is_empty() {
220 return Err(input.error("sql_forge!: map argument cannot be empty"));
221 }
222
223 if content.peek(Token![>]) {
224 Ok(Some(MapKind::Results))
225 } else if content.peek(Token![:]) {
226 Ok(Some(MapKind::Params))
227 } else if content.peek(Token![#]) {
228 Ok(Some(MapKind::Sections))
229 } else {
230 Ok(None)
231 }
232}
233
234impl Parse for ResultAssign {
235 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
236 input.parse::<Token![>]>()?;
237 let name: Ident = input.parse()?;
238 input.parse::<Token![=]>()?;
239 let (force_scalar, model) = if input.peek(kw::scalar) {
240 input.parse::<kw::scalar>()?;
241 (true, input.parse::<Type>()?)
242 } else {
243 (false, input.parse::<Type>()?)
244 };
245 Ok(Self {
246 name,
247 model,
248 force_scalar,
249 })
250 }
251}
252
253fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
255 let content;
256 syn::parenthesized!(content in input);
257
258 let mut results = Vec::new();
259 while !content.is_empty() {
260 results.push(content.parse::<ResultAssign>()?);
261 if content.is_empty() {
262 break;
263 }
264 content.parse::<Token![,]>()?;
265 }
266
267 if results.is_empty() {
268 return Err(input.error("sql_forge!: result map cannot be empty"));
269 }
270
271 Ok(results)
272}
273
274fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
276 let content;
277 syn::parenthesized!(content in input);
278
279 let mut params = Vec::new();
280 while !content.is_empty() {
281 params.push(content.parse::<ParamAssign>()?);
282 if content.is_empty() {
283 break;
284 }
285 content.parse::<Token![,]>()?;
286 }
287
288 Ok(params)
289}
290
291fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
293 let content;
294 syn::parenthesized!(content in input);
295
296 let mut sections = Vec::new();
297 while !content.is_empty() {
298 sections.push(content.parse::<SectionAssign>()?);
299 if content.is_empty() {
300 break;
301 }
302 content.parse::<Token![,]>()?;
303 }
304
305 Ok(sections)
306}
307
308fn parse_params_source_expr(input: ParseStream<'_>) -> syn::Result<ParamsSource> {
310 if input.peek(syn::token::Paren) {
311 match detect_parenthesized_map_kind(input)? {
312 Some(MapKind::Results) => Err(input
313 .error("sql_forge!: result maps are only allowed as the macro result argument")),
314 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
315 Some(MapKind::Sections) => Err(input.error(
316 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317 )),
318 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319 }
320 } else {
321 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322 }
323}
324
325fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
327 if input.peek(syn::token::Paren) {
328 let fork = input.fork();
329 let content;
330 syn::parenthesized!(content in fork);
331
332 if let Ok(first_expr) = content.parse::<Expr>() {
333 if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
334 let _ = parse_params_source_expr(&content)?;
335 if content.peek(Token![,]) {
336 content.parse::<Token![,]>()?;
337 }
338 if content.is_empty() {
339 let content;
340 syn::parenthesized!(content in input);
341 let first_expr: Expr = content.parse()?;
342 let sql = extract_lit_str(&first_expr).ok_or_else(|| {
343 input.error("sql_forge!: section tuple must start with a string literal")
344 })?;
345 let span = first_expr.span();
346 content.parse::<Token![,]>()?;
347 let params = parse_params_source_expr(&content)?;
348 if content.peek(Token![,]) {
349 content.parse::<Token![,]>()?;
350 }
351 if !content.is_empty() {
352 return Err(content.error(
353 "sql_forge!: unexpected tokens after section-local parameter source",
354 ));
355 }
356 return Ok(SectionFragment { sql, span, params });
357 }
358 }
359 }
360 }
361
362 let expr: Expr = input.parse()?;
363 let sql = extract_lit_str(&expr).ok_or_else(|| {
364 input
365 .error("sql_forge!: section values must be string literals or (string literal, params)")
366 })?;
367 Ok(SectionFragment {
368 sql,
369 span: expr.span(),
370 params: ParamsSource::None,
371 })
372}
373
374fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
376 if input.peek(Token![match]) {
377 input.parse::<Token![match]>()?;
378 let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
379 let content;
380 syn::braced!(content in input);
381 let mut arms = Vec::new();
382 while !content.is_empty() {
383 let pat = content.call(Pat::parse_multi_with_leading_vert)?;
384 let guard = if content.peek(Token![if]) {
385 content.parse::<Token![if]>()?;
386 Some(content.parse::<Expr>()?)
387 } else {
388 None
389 };
390 content.parse::<Token![=>]>()?;
391 let value = parse_section_value(&content, width)?;
393 if content.peek(Token![,]) {
394 content.parse::<Token![,]>()?;
395 }
396 arms.push(SectionMatchArm { pat, guard, value });
397 }
398 return Ok(SectionValue::Match { expr, arms });
399 }
400
401 if width == 1 {
402 return Ok(SectionValue::Single(parse_section_fragment(input)?));
403 }
404
405 let content;
406 syn::parenthesized!(content in input);
407 let mut items = Vec::new();
408 while !content.is_empty() {
409 items.push(parse_section_value(&content, 1)?);
412 if content.is_empty() {
413 break;
414 }
415 content.parse::<Token![,]>()?;
416 }
417
418 if items.len() != width {
419 return Err(input.error(format!(
420 "sql_forge!: grouped section value must provide exactly {} items",
421 width,
422 )));
423 }
424
425 Ok(SectionValue::Grouped(items))
426}
427
impl Parse for SqlForgeInput {
    /// Parses the full `sql_forge!` argument list.
    ///
    /// Leading forms (everything up to and including the SQL template):
    /// - `"sql"`
    /// - `scalar Model, "sql"`
    /// - `(>a = A, ...), "sql"` — any other parenthesized first argument is rejected
    /// - `Db, "sql"` when `is_db_type(Db)`, otherwise `Model, "sql"`
    /// - `Db, scalar Model, "sql"`
    /// - `Db, (>a = A, ...), "sql"`
    /// - `Db, Model, "sql"`
    ///
    /// After the template, a comma-separated mix (trailing comma allowed) of:
    /// - `..expr` — batch source, at most one;
    /// - `(:name = expr, ...)` or any other expression — parameter source, at most one;
    /// - `(#name = value, ...)` — section map, at most one.
    ///
    /// NOTE(review): because a leading `( ... )` is only accepted as a result map,
    /// a tuple model type cannot be passed in first position — confirm intended.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            // `"sql"`
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            // `scalar Model, "sql"`
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // `(>a = A, ...), "sql"`
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // The first argument is a type: either the DB type or the model type,
            // disambiguated by what follows the comma.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) && is_db_type(&first_ty) {
                // `Db, "sql"`
                let db = first_ty;
                let sql = parse_sql_template(input)?;
                (Some(db), ResultSpec::None, false, sql)
            } else if input.peek(LitStr) {
                // `Model, "sql"`
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                // `Db, scalar Model, "sql"`
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                // `Db, (>a = A, ...), "sql"`
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `Db, Model, "sql"`
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Optional trailing arguments after the SQL template.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    // `..expr` — batch source.
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    // Parenthesized argument: classify it by its first inner token.
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // An ordinary parenthesized expression is a struct-style source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    // Any other expression is a struct-style parameter source.
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Arguments are comma-separated; a missing comma ends the list and
                // any leftover tokens are reported below.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
581
582fn resolve_db_from_env() -> Result<Type, String> {
588 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
589 return syn::parse_str::<Type>(&val).map_err(|err| {
590 format!(
591 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
592 val, err
593 )
594 });
595 }
596
597 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
598 Ok(d) => d,
599 Err(_) => {
600 return Err(
601 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
602 or configure [package.metadata.sql_forge] in Cargo.toml"
603 .to_string(),
604 );
605 }
606 };
607 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
608
609 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
610 format!(
611 "sql_forge!: failed to read {}: {}",
612 manifest_path.display(),
613 err
614 )
615 })?;
616
617 let value: toml::Value = toml::from_str(&cargo_toml)
618 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
619
620 let db_str = value
621 .get("package")
622 .and_then(|v| v.get("metadata"))
623 .and_then(|v| v.get("sql_forge"))
624 .and_then(|v| v.get("db"))
625 .and_then(|v| v.as_str())
626 .ok_or({
627 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
628 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
629 })?;
630
631 syn::parse_str::<Type>(db_str).map_err(|err| {
632 format!(
633 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
634 db_str, err
635 )
636 })
637}
638
639fn uses_dollar_params(db: &Type) -> bool {
641 let Type::Path(type_path) = db else {
642 return false;
643 };
644 type_path
645 .path
646 .segments
647 .last()
648 .is_some_and(|s| s.ident == "Postgres")
649}
650
651fn is_db_type(ty: &Type) -> bool {
653 let Type::Path(type_path) = ty else {
654 return false;
655 };
656 if type_path.qself.is_some() {
657 return false;
658 }
659 let segs = &type_path.path.segments;
660 if segs.len() != 2 {
661 return false;
662 }
663 segs[0].ident == "sqlx"
664 && ["MySql", "Postgres", "Sqlite"].contains(&segs[1].ident.to_string().as_str())
665}
666
667fn is_builtin_scalar_type(ty: &Type) -> bool {
669 let Type::Path(type_path) = ty else {
670 return false;
671 };
672
673 if type_path.qself.is_some()
674 || type_path.path.leading_colon.is_some()
675 || type_path.path.segments.len() != 1
676 {
677 return false;
678 }
679
680 let ident = &type_path.path.segments[0].ident;
681 ident == "i8"
682 || ident == "i16"
683 || ident == "i32"
684 || ident == "i64"
685 || ident == "isize"
686 || ident == "u8"
687 || ident == "u16"
688 || ident == "u32"
689 || ident == "u64"
690 || ident == "usize"
691 || ident == "f32"
692 || ident == "f64"
693 || ident == "bool"
694 || ident == "String"
695}
696
697fn scalar_output_type(model: &Type) -> Option<&Type> {
699 if is_builtin_scalar_type(model) {
700 return Some(model);
701 }
702 None
703}
704
705fn push_text_segment(out: &mut Vec<Segment>, text: String) {
707 if text.is_empty() {
708 return;
709 }
710 match out.last_mut() {
711 Some(Segment::Text(existing)) => existing.push_str(&text),
712 _ => out.push(Segment::Text(text)),
713 }
714}
715
/// Splits a literal SQL template into segments.
///
/// Recognised markers:
/// - `{( ... )}` — a batch section. Its contents, including the outer
///   parentheses, are split with `parse_text_parts`. A nested `{` is rejected,
///   parentheses must be balanced when the closing `}` is reached, and list
///   parameters (`:name[]`) are rejected.
/// - `{#name}` — a section placeholder. Everything up to the next `}` is taken
///   verbatim as the name (not trimmed or validated here); an empty name is an
///   error.
///
/// Any other `{` is kept as ordinary text. Consecutive text is merged via
/// `push_text_segment`. Errors are returned as ready-to-report messages.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    // Text accumulated since the last marker.
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        // `{(` starts a batch section.
        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            // The `(` has only been peeked, so the loop below consumes and counts
            // it, and it ends up at the start of `content`.
            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    // The closing brace is only valid once every `(` is closed.
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` not followed by `#` (or `(`) is plain text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // `{#name}` section placeholder: consume the `#`, then read up to `}`.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    // Flush any trailing text.
    push_text_segment(&mut out, text);
    Ok(out)
}
809
/// Returns `true` for characters that may start a parameter name:
/// an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
818
/// Returns `true` for characters that may continue a parameter name:
/// an ASCII letter, an ASCII digit, or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch == '_' || ch.is_ascii_alphanumeric()
}
823
/// Strips a type-override suffix from the contents of a backticked alias.
///
/// Everything from the first `!`, `?` or `:` onward is dropped and trailing
/// whitespace is trimmed, e.g. `"id!"` → `"id"` and `"ts: Foo"` → `"ts"`. If
/// there is no such character, or nothing would remain, `content` is returned
/// unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker_at) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };
    match content[..marker_at].trim_end() {
        "" => content.to_string(),
        alias => alias.to_string(),
    }
}
845
846fn sanitize_runtime_sql_text(text: &str) -> String {
848 let mut out = String::with_capacity(text.len());
849 let mut chars = text.chars().peekable();
850
851 while let Some(ch) = chars.next() {
852 if ch != '`' {
853 out.push(ch);
854 continue;
855 }
856
857 let mut content = String::new();
858 let mut closed = false;
859
860 for next in chars.by_ref() {
861 if next == '`' {
862 closed = true;
863 break;
864 }
865 content.push(next);
866 }
867
868 if closed {
869 out.push('`');
870 out.push_str(&sanitize_backticked_alias_ident(&content));
871 out.push('`');
872 } else {
873 out.push('`');
874 out.push_str(&content);
875 break;
876 }
877 }
878
879 out
880}
881
/// Splits SQL text into literal runs and `:name` / `:name[]` parameter references.
///
/// A parameter starts at a `:` that is immediately followed by an ASCII letter
/// or `_` and that is not itself preceded by `:`, so `::type` casts are left
/// alone. The name continues over ASCII letters, digits and `_`. A directly
/// following `[]` marks a list parameter. Everything else becomes `TextPart::Lit`.
/// The scan is purely lexical: a `:name` inside a quoted SQL string is also
/// treated as a parameter.
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset of the first character not yet emitted as a part.
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // The second `:` of `::name` (e.g. a `::type` cast) is not a parameter.
        if text[..idx].ends_with(':') {
            continue;
        }

        // Flush the literal text preceding the `:`.
        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        // Consume the already-peeked first name character.
        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        // Byte offset just past the name.
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        // `[]` is skipped by advancing `last` past it. It is not consumed from
        // `iter`, which is harmless because neither character is `:`.
        let mut is_list = false;
        if text[end..].starts_with("[]") {
            is_list = true;
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    // Trailing literal text.
    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
941
942#[allow(clippy::type_complexity)]
953fn render_validator_sql(
954 parts: &[TextPart],
955 use_dollar_params: bool,
956 param_offset: &mut usize,
957 list_count: usize,
958 batch_expr: Option<TokenStream2>,
959) -> Result<(String, Vec<(String, bool)>, Vec<TokenStream2>), TokenStream> {
960 let mut out_sql = String::new();
961 let mut occurrences = Vec::new();
962 let mut batch_args = Vec::new();
963
964 for part in parts {
965 match part {
966 TextPart::Lit(lit) => out_sql.push_str(lit),
967 TextPart::Param { name, is_list } => {
968 if let Some(ref batch_expr) = batch_expr {
969 if *is_list {
970 return Err(syn::Error::new(
971 Span::call_site(),
972 "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
973 batch sections; use plain parameters (:name) instead",
974 )
975 .to_compile_error()
976 .into());
977 }
978 let field_ident = format_ident!("{}", name);
979 if use_dollar_params {
980 *param_offset += 1;
981 write!(out_sql, "${}", *param_offset).unwrap();
982 } else {
983 out_sql.push('?');
984 }
985 batch_args.push(quote! { #batch_expr[0].#field_ident });
986 } else if *is_list && list_count > 1 {
987 let slots: Vec<String> = if use_dollar_params {
988 (0..list_count)
989 .map(|i| format!("${}", *param_offset + i + 1))
990 .collect()
991 } else {
992 (0..list_count).map(|_| "?".to_string()).collect()
993 };
994 if use_dollar_params {
995 *param_offset += list_count;
996 }
997 out_sql.push_str(&slots.join(", "));
998 occurrences.push((name.clone(), *is_list));
999 } else {
1000 if use_dollar_params {
1001 *param_offset += 1;
1002 write!(out_sql, "${}", *param_offset).unwrap();
1003 } else {
1004 out_sql.push('?');
1005 }
1006 occurrences.push((name.clone(), *is_list));
1007 }
1008 }
1009 }
1010 }
1011
1012 Ok((out_sql, occurrences, batch_args))
1013}
1014
1015fn strip_expr(expr: &Expr) -> &Expr {
1038 match expr {
1039 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
1040 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
1041 Expr::Block(ExprBlock { block, .. }) => {
1042 if block.stmts.len() != 1 {
1043 return expr;
1044 }
1045 match &block.stmts[0] {
1046 Stmt::Expr(inner, None) => strip_expr(inner),
1047 _ => expr,
1048 }
1049 }
1050 _ => expr,
1051 }
1052}
1053
1054fn extract_lit_str(expr: &Expr) -> Option<String> {
1056 match strip_expr(expr) {
1057 Expr::Lit(ExprLit {
1058 lit: Lit::Str(lit), ..
1059 }) => Some(lit.value()),
1060 _ => None,
1061 }
1062}
1063
/// Builds the identifier that stands in for a `{>name}` result-flag placeholder.
///
/// `preprocess_result_key_placeholders` rewrites `{>name}` to this identifier,
/// `build_result_flag_bindings` declares it as a `bool`, and
/// `expr_result_flag_key` recovers `name` by stripping the same prefix.
fn result_flag_ident(name: &str) -> syn::Ident {
    format_ident!("__sql_forge_result_flag_{}", name)
}
1072
1073fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
1077 fn walk(stream: TokenStream2) -> TokenStream2 {
1078 let mut out = TokenStream2::new();
1079 let iter = stream.into_iter().peekable();
1080
1081 for token in iter {
1082 match token {
1083 TokenTree::Group(group) => {
1084 if group.delimiter() == Delimiter::Brace {
1085 let mut inner = group.stream().into_iter();
1086 let first = inner.next();
1087 let second = inner.next();
1088 let third = inner.next();
1089
1090 if let (
1091 Some(TokenTree::Punct(p)),
1092 Some(TokenTree::Ident(name_ident)),
1093 None,
1094 ) = (first, second, third)
1095 {
1096 if p.as_char() == '>' {
1097 let ident = result_flag_ident(&name_ident.to_string());
1098 out.extend(std::iter::once(TokenTree::Ident(ident)));
1099 continue;
1100 }
1101 }
1102 }
1103
1104 let new_inner = walk(group.stream());
1105 let mut new_group = Group::new(group.delimiter(), new_inner);
1106 new_group.set_span(group.span());
1107 out.extend(std::iter::once(TokenTree::Group(new_group)));
1108 }
1109 other => out.extend(std::iter::once(other)),
1110 }
1111 }
1112
1113 out
1114 }
1115
1116 walk(input)
1117}
1118
1119fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1121 keys.iter()
1122 .map(|key| {
1123 let ident = result_flag_ident(key);
1124 let enabled = Some(key.as_str()) == active_key;
1125 quote! { let #ident: bool = #enabled; }
1126 })
1127 .collect()
1128}
1129
1130fn transpose_section_case_matrix(
1132 case_matrix: Vec<Vec<SectionFragment>>,
1133 width: usize,
1134) -> Result<Vec<Vec<SectionFragment>>, String> {
1135 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1136
1137 for row in case_matrix {
1138 if row.len() != width {
1139 return Err(
1140 "sql_forge!: grouped sections must return one item per section".to_string(),
1141 );
1142 }
1143 for (section_idx, fragment) in row.into_iter().enumerate() {
1144 per_section[section_idx].push(fragment);
1145 }
1146 }
1147
1148 Ok(per_section)
1149}
1150
1151fn collect_section_case_matrix(
1153 value: SectionValue,
1154 width: usize,
1155 active_key: Option<&str>,
1156) -> Result<Vec<Vec<SectionFragment>>, String> {
1157 match value {
1158 SectionValue::Single(fragment) => {
1159 if width != 1 {
1160 return Err(
1161 "sql_forge!: grouped sections must return one item per section".to_string(),
1162 );
1163 }
1164 Ok(vec![vec![fragment]])
1165 }
1166 SectionValue::Grouped(values) => {
1167 if values.len() != width {
1168 return Err(
1169 "sql_forge!: grouped sections must return one item per section".to_string(),
1170 );
1171 }
1172
1173 let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1174 let mut nmax = 1usize;
1175
1176 for value in values {
1177 let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1178 let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1179 for mut row in item_matrix {
1180 let fragment = row.pop().ok_or_else(|| {
1181 "sql_forge!: grouped sections must return one item per section".to_string()
1182 })?;
1183 if !row.is_empty() {
1184 return Err(
1185 "sql_forge!: grouped sections must return one item per section"
1186 .to_string(),
1187 );
1188 }
1189 item_variants.push(fragment);
1190 }
1191 if item_variants.is_empty() {
1192 return Err("sql_forge!: section match must have at least one arm".to_string());
1193 }
1194 nmax = nmax.max(item_variants.len());
1195 variants_by_section.push(item_variants);
1196 }
1197
1198 let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1199 for case_idx in 0..nmax {
1200 let mut row = Vec::<SectionFragment>::with_capacity(width);
1201 for variants in &variants_by_section {
1202 row.push(variants[case_idx % variants.len()].clone());
1203 }
1204 case_matrix.push(row);
1205 }
1206
1207 Ok(case_matrix)
1208 }
1209 SectionValue::Match { expr, arms } => {
1210 let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1211
1212 if let Some(key) = expr_result_flag_key(&expr) {
1213 let target = active_key == Some(key.as_str());
1214 for arm in arms {
1215 if arm.guard.is_none() {
1216 if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1217 continue;
1218 }
1219 }
1220 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1221 wrap_section_case_matrix_for_match_arm(
1222 &mut arm_cases,
1223 &expr,
1224 &arm.pat,
1225 arm.guard.as_ref(),
1226 );
1227 case_matrix.extend(arm_cases);
1228 }
1229 } else {
1230 for arm in arms {
1231 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1232 wrap_section_case_matrix_for_match_arm(
1233 &mut arm_cases,
1234 &expr,
1235 &arm.pat,
1236 arm.guard.as_ref(),
1237 );
1238 case_matrix.extend(arm_cases);
1239 }
1240 }
1241
1242 if case_matrix.is_empty() {
1243 return Err("sql_forge!: section match must have at least one arm".to_string());
1244 }
1245
1246 Ok(case_matrix)
1247 }
1248 }
1249}
1250
/// Wraps `expr` so that it is evaluated inside the given match arm.
///
/// Generates
/// `match &(match_expr) { pat [if guard] => { ... }, _ => unreachable!(..) }`,
/// which puts the identifiers bound by `pat` in scope for `expr`. The
/// `unreachable!` message indicates this is used for validator code.
///
/// If the pattern is judged to bind values, each identifier returned by
/// `pat_var_idents` (defined elsewhere) is referenced with `let _ = &ident;`
/// before `expr`, presumably to silence unused-variable warnings, and the arm
/// yields `expr` by value. Otherwise the arm yields `&(expr)`.
///
/// NOTE(review): the binding check looks only one level deep, so e.g.
/// `Some(Some(x))` is treated as non-binding. The two branches also yield
/// different types (`T` vs `&T`); confirm every consumer accepts both.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    // Heuristic: does the pattern (or one of its direct sub-patterns) bind an identifier?
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        // Touch every bound identifier so none is reported as unused.
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1320
1321fn wrap_params_source_for_match_arm(
1323 params: &mut ParamsSource,
1324 match_expr: &Expr,
1325 pat: &Pat,
1326 guard: Option<&Expr>,
1327) {
1328 match params {
1329 ParamsSource::None => {}
1330 ParamsSource::Map(entries) => {
1331 for entry in entries {
1332 entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1333 }
1334 }
1335 ParamsSource::Struct(expr) => {
1336 **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1337 }
1338 }
1339}
1340
1341fn wrap_section_case_matrix_for_match_arm(
1343 case_matrix: &mut [Vec<SectionFragment>],
1344 match_expr: &Expr,
1345 pat: &Pat,
1346 guard: Option<&Expr>,
1347) {
1348 for row in case_matrix {
1349 for fragment in row {
1350 wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1351 }
1352 }
1353}
1354
1355fn collect_section_variants(
1365 value: SectionValue,
1366 width: usize,
1367) -> Result<Vec<Vec<SectionFragment>>, String> {
1368 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1369}
1370
1371fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1373 match strip_expr(expr) {
1374 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1375 let name = path.path.segments[0].ident.to_string();
1376 name.strip_prefix("__sql_forge_result_flag_")
1377 .map(|v| v.to_string())
1378 }
1379 _ => None,
1380 }
1381}
1382
1383fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1385 match pat {
1386 Pat::Lit(expr_lit) => match &expr_lit.lit {
1387 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1388 _ => None,
1389 },
1390 Pat::Wild(_) => Some(true),
1391 _ => None,
1392 }
1393}
1394
1395fn collect_section_variants_for_result(
1400 value: SectionValue,
1401 width: usize,
1402 active_key: Option<&str>,
1403) -> Result<Vec<Vec<SectionFragment>>, String> {
1404 transpose_section_case_matrix(
1405 collect_section_case_matrix(value, width, active_key)?,
1406 width,
1407 )
1408}
1409
1410fn build_param_bindings(
1418 params: &ParamsSource,
1419 used_param_names: &[String],
1420 prefix: &str,
1421 for_validator: bool,
1422 enforce_usage_check: bool,
1423) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1424 let mut declared_params = HashMap::<String, syn::Ident>::new();
1425 let mut bindings = Vec::<TokenStream2>::new();
1426
1427 match params {
1428 ParamsSource::None => {}
1429 ParamsSource::Map(entries) => {
1430 for entry in entries {
1431 let key = entry.name.to_string();
1432 if declared_params.contains_key(&key) {
1433 return Err(syn::Error::new(
1434 entry.name.span(),
1435 "sql_forge!: duplicated parameter mapping",
1436 )
1437 .to_compile_error()
1438 .into());
1439 }
1440 if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1441 return Err(syn::Error::new(
1442 entry.name.span(),
1443 format!(
1444 "sql_forge!: parameter :{} is unused in the SQL template",
1445 key,
1446 ),
1447 )
1448 .to_compile_error()
1449 .into());
1450 }
1451 let local_ident = format_ident!("__sql_forge_{}_{}", prefix, key);
1452 let expr = &entry.expr;
1453 if for_validator {
1454 bindings.push(quote! {
1455 let #local_ident = &(#expr);
1456 });
1457 } else {
1458 bindings.push(quote! {
1459 let #local_ident = #expr;
1460 });
1461 }
1462 declared_params.insert(key, local_ident);
1463 }
1464 }
1465 ParamsSource::Struct(expr) => {
1466 let source_ident = format_ident!("__sql_forge_source_{}", prefix);
1467 bindings.push(quote! {
1468 let #source_ident = &(#expr);
1469 });
1470 for name in used_param_names {
1471 let local_ident = format_ident!("__sql_forge_{}_{}", prefix, name);
1472 let field_ident = format_ident!("{}", name);
1473 if for_validator {
1474 bindings.push(quote! {
1475 let #local_ident = &#source_ident.#field_ident;
1476 });
1477 } else {
1478 bindings.push(quote! {
1479 let #local_ident = #source_ident.#field_ident;
1480 });
1481 }
1482 declared_params.insert(name.to_string(), local_ident);
1483 }
1484 }
1485 }
1486
1487 Ok((declared_params, bindings))
1488}
1489
1490struct ValidatorRenderContext<'a> {
1491 params: &'a HashMap<String, syn::Ident>,
1492 use_dollar_params: bool,
1493 sql_span: Span,
1494 list_count: usize,
1495}
1496
1497fn render_validator_args(
1502 sql: &str,
1503 param_offset: &mut usize,
1504 arg_index: &mut usize,
1505 context: &ValidatorRenderContext<'_>,
1506) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1507 let parts = parse_text_parts(sql);
1508 let (rendered_sql, occurrences, _batch_args) = render_validator_sql(
1509 &parts,
1510 context.use_dollar_params,
1511 param_offset,
1512 context.list_count,
1513 None,
1514 )?;
1515
1516 let mut setup = Vec::<TokenStream2>::new();
1517 let mut args = Vec::<TokenStream2>::new();
1518
1519 for (name, is_list) in occurrences {
1520 let Some(local_ident) = context.params.get(&name) else {
1521 return Err(syn::Error::new(
1522 context.sql_span,
1523 format!("sql_forge!: parameter :{} has no mapping", name),
1524 )
1525 .to_compile_error()
1526 .into());
1527 };
1528
1529 if is_list {
1530 for _ in 0..context.list_count {
1531 let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1532 *arg_index += 1;
1533 if context.use_dollar_params {
1534 setup.push(quote! {
1535 let #value_ident = sql_forge::sql_forge_validator_value(
1536 (#local_ident)
1537 .as_slice()
1538 .first()
1539 .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1540 );
1541 });
1542 } else {
1543 setup.push(quote! {
1544 let #value_ident = (#local_ident)
1545 .as_slice()
1546 .first()
1547 .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1548 });
1549 }
1550 args.push(quote! { #value_ident });
1551 }
1552 } else {
1553 let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1554 *arg_index += 1;
1555 if context.use_dollar_params {
1556 setup.push(quote! {
1557 let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1558 });
1559 } else {
1560 setup.push(quote! {
1561 let #value_ident = #local_ident;
1562 });
1563 }
1564 args.push(quote! { #value_ident });
1565 }
1566 }
1567
1568 Ok((rendered_sql, setup, args))
1569}
1570
1571fn render_runtime_fragment(
1578 fragment: &SectionFragment,
1579 local_params: &HashMap<String, syn::Ident>,
1580) -> Result<TokenStream2, TokenStream> {
1581 let mut steps = Vec::<TokenStream2>::new();
1582
1583 for part in parse_text_parts(&fragment.sql) {
1584 match part {
1585 TextPart::Lit(lit) => {
1586 let lit_str = LitStr::new(&lit, fragment.span);
1587 steps.push(quote! { __builder.push(#lit_str); });
1588 }
1589 TextPart::Param { name, is_list } => {
1590 let Some(local_ident) = local_params.get(&name) else {
1591 return Err(syn::Error::new(
1592 fragment.span,
1593 format!("sql_forge!: parameter :{} has no mapping", name),
1594 )
1595 .to_compile_error()
1596 .into());
1597 };
1598
1599 if is_list {
1600 steps.push(quote! {
1601 let __sql_forge_values = #local_ident;
1602 let mut __separated = __builder.separated(", ");
1603 for __value in __sql_forge_values {
1604 __separated.push_bind(__value);
1605 }
1606 });
1607 } else {
1608 steps.push(quote! {
1609 __builder.push_bind(#local_ident);
1610 });
1611 }
1612 }
1613 }
1614 }
1615
1616 Ok(quote! { #( #steps )* })
1617}
1618
1619fn is_pat_binding(ident: &Ident) -> bool {
1621 let name = ident.to_string();
1622 !name.is_empty()
1623 && name
1624 .chars()
1625 .next()
1626 .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1627}
1628
1629fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1631 let mut names = Vec::new();
1632 fn walk(p: &Pat, names: &mut Vec<Ident>) {
1633 match p {
1634 Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1635 Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1636 Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1637 Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1638 Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1639 Pat::Paren(pp) => walk(&pp.pat, names),
1640 Pat::Reference(pr) => walk(&pr.pat, names),
1641 Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1642 Pat::Type(pt) => walk(&pt.pat, names),
1643 _ => {}
1644 }
1645 }
1646 walk(pat, &mut names);
1647 names
1648}
1649
1650fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1652 match value {
1653 SectionValue::Single(f) => {
1654 if collect_used_param_names_in_sql(&f.sql)
1655 .iter()
1656 .any(|n| n == name)
1657 {
1658 return true;
1659 }
1660 if let ParamsSource::Map(entries) = &f.params {
1661 for e in entries {
1662 let expr = &e.expr;
1663 let expr_str = quote! { #expr }.to_string();
1664 if expr_str.trim() == name {
1665 return true;
1666 }
1667 }
1668 }
1669 false
1670 }
1671 SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1672 SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1673 let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1674 .into_iter()
1675 .map(|i| i.to_string())
1676 .collect();
1677 if pat_vars.contains(name) {
1678 false
1679 } else {
1680 section_value_refers_to(&arm.value, name)
1681 }
1682 }),
1683 }
1684}
1685
1686fn build_section_runtime_action(
1688 value: &SectionValue,
1689 section_idx: usize,
1690 prefix: &str,
1691) -> Result<TokenStream2, TokenStream> {
1692 match value {
1693 SectionValue::Single(fragment) => {
1694 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1695 let (local_params, bindings) =
1696 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1697 let body = render_runtime_fragment(fragment, &local_params)?;
1698 Ok(quote! {{ #( #bindings )* #body }})
1699 }
1700 SectionValue::Grouped(fragments) => build_section_runtime_action(
1701 &fragments[section_idx],
1702 0,
1703 &format!("{}_grouped_{}", prefix, section_idx),
1704 ),
1705 SectionValue::Match { expr, arms } => {
1706 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1707 .iter()
1708 .enumerate()
1709 .map(|(arm_idx, arm)| {
1710 let pat = &arm.pat;
1711 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1712 let body = build_section_runtime_action(
1713 &arm.value,
1714 section_idx,
1715 &format!("{}_{}", prefix, arm_idx),
1716 )?;
1717 let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1718 .into_iter()
1719 .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1720 .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1721 .collect();
1722 Ok::<TokenStream2, TokenStream>(quote! {
1723 #pat #guard_tokens => {
1724 #( #noop_refs )*
1725 #body
1726 }
1727 })
1728 })
1729 .collect();
1730 let arm_tokens = arm_tokens?;
1731 Ok(quote! {
1732 match #expr {
1733 #( #arm_tokens ),*
1734 }
1735 })
1736 }
1737 }
1738}
1739
1740fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1742 let mut names = Vec::new();
1743 let mut seen = HashSet::<String>::new();
1744
1745 for segment in segments {
1746 match segment {
1747 Segment::Text(text) => {
1748 for name in collect_used_param_names_in_sql(text) {
1749 if seen.insert(name.clone()) {
1750 names.push(name);
1751 }
1752 }
1753 }
1754 Segment::Batch { parts } => {
1755 for part in parts {
1756 if let TextPart::Param { name, .. } = part {
1757 if seen.insert(name.clone()) {
1758 names.push(name.clone());
1759 }
1760 }
1761 }
1762 }
1763 _ => {}
1764 }
1765 }
1766
1767 names
1768}
1769
1770fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1772 let mut names = Vec::new();
1773 let mut seen = HashSet::<String>::new();
1774 for part in parse_text_parts(sql) {
1775 if let TextPart::Param { name, .. } = part {
1776 if seen.insert(name.to_string()) {
1777 names.push(name);
1778 }
1779 }
1780 }
1781 names
1782}
1783
1784#[proc_macro]
2010#[allow(clippy::too_many_lines)]
2011pub fn sql_forge(input: TokenStream) -> TokenStream {
2012 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
2014 let SqlForgeInput {
2015 db,
2016 result,
2017 force_scalar,
2018 sql,
2019 params,
2020 sections,
2021 batch,
2022 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
2023 Ok(v) => v,
2024 Err(err) => return err.to_compile_error().into(),
2025 };
2026
2027 let db = match db {
2029 Some(db) => db,
2030 None => match resolve_db_from_env() {
2031 Ok(db) => db,
2032 Err(msg) => {
2033 return syn::Error::new(Span::call_site(), msg)
2034 .to_compile_error()
2035 .into();
2036 }
2037 },
2038 };
2039
2040 let use_dollar_params = uses_dollar_params(&db);
2041 let list_count: usize = 3;
2042
2043 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
2047 ResultSpec::None => {
2048 vec![(None, None, None)]
2049 }
2050 ResultSpec::Single(ref model) => {
2051 let model_ty = (**model).clone();
2052 let scalar = if force_scalar {
2053 Some(model_ty.clone())
2054 } else {
2055 scalar_output_type(model.as_ref()).cloned()
2056 };
2057 vec![(None, Some(model_ty), scalar)]
2058 }
2059 ResultSpec::Group(ref cases) => {
2060 if force_scalar {
2061 return syn::Error::new(
2062 Span::call_site(),
2063 "sql_forge!: scalar mode is not supported for grouped result maps",
2064 )
2065 .to_compile_error()
2066 .into();
2067 }
2068
2069 let mut out = Vec::new();
2070 let mut seen = HashSet::new();
2071 for case in cases {
2072 let key = case.name.to_string();
2073 if !seen.insert(key.clone()) {
2074 return syn::Error::new(
2075 case.name.span(),
2076 "sql_forge!: duplicated key in result map",
2077 )
2078 .to_compile_error()
2079 .into();
2080 }
2081
2082 let model = case.model.clone();
2083 let scalar = if case.force_scalar {
2084 Some(model.clone())
2085 } else {
2086 scalar_output_type(&case.model).cloned()
2087 };
2088 out.push((Some(key), Some(model), scalar));
2089 }
2090 out
2091 }
2092 };
2093 let group_result_keys: Vec<String> = result_cases
2094 .iter()
2095 .filter_map(|(key, _, _)| key.as_ref().cloned())
2096 .collect();
2097 let is_grouped_result = !group_result_keys.is_empty();
2098 let sql_span = sql.span();
2099
2100 let segments = match sql.into_segments() {
2102 Ok(segments) => segments,
2103 Err(msg) => {
2104 return syn::Error::new(sql_span, msg).to_compile_error().into();
2105 }
2106 };
2107
2108 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2109 match (&batch, has_batch_segment) {
2110 (None, true) => {
2111 return syn::Error::new(
2112 sql_span,
2113 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2114 was provided"
2115 )
2116 .to_compile_error()
2117 .into();
2118 }
2119 (Some(_), false) => {
2120 return syn::Error::new(
2121 sql_span,
2122 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2123 batch section",
2124 )
2125 .to_compile_error()
2126 .into();
2127 }
2128 _ => {}
2129 }
2130
2131 let used_param_names = collect_used_param_names(&segments);
2132
2133 let text_param_names: std::collections::HashSet<String> = segments
2141 .iter()
2142 .filter_map(|s| {
2143 if let Segment::Text(text) = s {
2144 Some(collect_used_param_names_in_sql(text).into_iter())
2145 } else {
2146 None
2147 }
2148 })
2149 .flatten()
2150 .collect();
2151 let top_level_used_names: Vec<String> = used_param_names
2152 .iter()
2153 .filter(|n| text_param_names.contains(*n))
2154 .cloned()
2155 .collect();
2156
2157 let (declared_params, validator_param_bindings) =
2159 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
2160 Ok(v) => v,
2161 Err(err) => return err,
2162 };
2163
2164 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2165
2166 for assign in §ions {
2168 let SectionAssign { names, value } = assign;
2169
2170 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2172 for (section_idx, name_ident) in names.iter().enumerate() {
2173 let name = name_ident.to_string();
2174 if runtime_section_actions.contains_key(&name) {
2175 return syn::Error::new(
2176 name_ident.span(),
2177 "sql_forge!: duplicated section mapping",
2178 )
2179 .to_compile_error()
2180 .into();
2181 }
2182 let action = match build_section_runtime_action(
2183 value,
2184 section_idx,
2185 &format!("section_{}", name),
2186 ) {
2187 Ok(action) => action,
2188 Err(err) => return err,
2189 };
2190 named_actions.push((name, action));
2191 }
2192
2193 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2195 return syn::Error::new(names[0].span(), msg)
2196 .to_compile_error()
2197 .into();
2198 }
2199
2200 for (name, action) in named_actions {
2201 runtime_section_actions.insert(name, action);
2202 }
2203 }
2204
2205 let sql_section_names: std::collections::HashSet<&str> = segments
2206 .iter()
2207 .filter_map(|seg| {
2208 if let Segment::Section { name } = seg {
2209 Some(name.as_str())
2210 } else {
2211 None
2212 }
2213 })
2214 .collect();
2215 for name in runtime_section_actions.keys() {
2216 if !sql_section_names.contains(name.as_str()) {
2217 return syn::Error::new(
2218 sql_span,
2219 format!(
2220 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2221 name, name,
2222 ),
2223 )
2224 .to_compile_error()
2225 .into();
2226 }
2227 }
2228
2229 let mut generated_query_defs = Vec::<TokenStream2>::new();
2231 let mut generated_query_values = Vec::<TokenStream2>::new();
2232 let mut group_field_defs = Vec::<TokenStream2>::new();
2233 let mut group_field_idents = Vec::<syn::Ident>::new();
2234 let mut group_field_tys = Vec::<TokenStream2>::new();
2235 let mut group_trait_impls = Vec::<TokenStream2>::new();
2236
2237 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2238
2239 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2240 let suffix = result_key.as_deref().unwrap_or("single");
2241 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2242 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2243
2244 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2245
2246 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2247 for assign in §ions {
2248 let SectionAssign { names, value } = assign;
2249 let variants_by_section = match collect_section_variants_for_result(
2250 value.clone(),
2251 names.len(),
2252 result_key.as_deref(),
2253 ) {
2254 Ok(v) => v,
2255 Err(msg) => {
2256 return syn::Error::new(names[0].span(), msg)
2257 .to_compile_error()
2258 .into();
2259 }
2260 };
2261
2262 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2263 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2264 }
2265 }
2266
2267 let mut nmax = 1usize;
2268 for segment in &segments {
2269 if let Segment::Section { name } = segment {
2270 if let Some(variants) = section_variants_for_validation.get(name) {
2271 if variants.is_empty() {
2272 return syn::Error::new(
2273 sql_span,
2274 format!("sql_forge!: section {{#{}}} has no possible variants", name),
2275 )
2276 .to_compile_error()
2277 .into();
2278 }
2279 nmax = nmax.max(variants.len());
2280 } else {
2281 return syn::Error::new(
2282 sql_span,
2283 format!("sql_forge!: section {{#{}}} has no mapping", name),
2284 )
2285 .to_compile_error()
2286 .into();
2287 }
2288 }
2289 }
2290
2291 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2292 for case_idx in 0..nmax {
2293 let mut sql_case = String::new();
2294 let mut case_setup = Vec::<TokenStream2>::new();
2295 let mut case_args = Vec::<TokenStream2>::new();
2296 let mut param_offset = 0usize;
2297 let mut arg_index = 0usize;
2298 let root_validator_context = ValidatorRenderContext {
2299 params: &declared_params,
2300 use_dollar_params,
2301 sql_span,
2302 list_count,
2303 };
2304
2305 for segment in &segments {
2306 match segment {
2307 Segment::Text(text) => {
2308 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2309 text,
2310 &mut param_offset,
2311 &mut arg_index,
2312 &root_validator_context,
2313 ) {
2314 Ok(value) => value,
2315 Err(err) => return err,
2316 };
2317 sql_case.push_str(&chunk_sql);
2318 case_setup.extend(chunk_setup);
2319 case_args.extend(chunk_args);
2320 }
2321 Segment::Section { name } => {
2322 let Some(variants) = section_variants_for_validation.get(name) else {
2323 return syn::Error::new(
2324 sql_span,
2325 format!("sql_forge!: section {{#{}}} has no mapping", name),
2326 )
2327 .to_compile_error()
2328 .into();
2329 };
2330
2331 let fragment = &variants[case_idx % variants.len()];
2332 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2333 let (local_params, bindings) = match build_param_bindings(
2334 &fragment.params,
2335 &used_param_names,
2336 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2337 true,
2338 true,
2339 ) {
2340 Ok(value) => value,
2341 Err(err) => return err,
2342 };
2343 let section_validator_context = ValidatorRenderContext {
2344 params: &local_params,
2345 use_dollar_params,
2346 sql_span: fragment.span,
2347 list_count,
2348 };
2349 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2350 &fragment.sql,
2351 &mut param_offset,
2352 &mut arg_index,
2353 §ion_validator_context,
2354 ) {
2355 Ok(value) => value,
2356 Err(err) => return err,
2357 };
2358 sql_case.push_str(&chunk_sql);
2359 case_setup.extend(bindings);
2360 case_setup.extend(chunk_setup);
2361 case_args.extend(chunk_args);
2362 }
2363 Segment::Batch { parts } => {
2364 let batch_ts = batch.as_ref().map(|e| quote! { #e });
2365 let mut first = true;
2366 for _ in 0..list_count {
2367 let sep = if first { "" } else { ", " };
2368 first = false;
2369 sql_case.push_str(sep);
2370 let (chunk_sql, _occurrences, chunk_args) = match render_validator_sql(
2371 parts,
2372 use_dollar_params,
2373 &mut param_offset,
2374 list_count,
2375 batch_ts.clone(),
2376 ) {
2377 Ok(value) => value,
2378 Err(err) => return err,
2379 };
2380 sql_case.push_str(&chunk_sql);
2381 case_args.extend(chunk_args);
2382 }
2383 }
2384 }
2385 }
2386
2387 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2388 }
2389
2390 let mut validator_invocations = Vec::<TokenStream2>::new();
2391 for (sql_lit, case_setup, args) in &validator_cases {
2392 if model_opt.is_none() {
2393 if args.is_empty() {
2394 validator_invocations.push(quote! {
2395 {
2396 #( #case_setup )*
2397 let _ = sqlx::query_scalar!(
2398 #sql_lit,
2399 );
2400 }
2401 });
2402 } else {
2403 validator_invocations.push(quote! {
2404 {
2405 #( #case_setup )*
2406 let _ = sqlx::query_scalar!(
2407 #sql_lit,
2408 #( #args ),*
2409 );
2410 }
2411 });
2412 }
2413 } else if let Some(scalar_ty) = scalar_model_ty {
2414 if args.is_empty() {
2415 validator_invocations.push(quote! {
2416 {
2417 #( #case_setup )*
2418 let _ = sqlx::query_scalar!(
2419 #sql_lit,
2420 );
2421 }
2422 });
2423 } else {
2424 validator_invocations.push(quote! {
2425 {
2426 #( #case_setup )*
2427 let _ = sqlx::query_scalar!(
2428 #sql_lit,
2429 #( #args ),*
2430 );
2431 }
2432 });
2433 }
2434 let _ = scalar_ty;
2435 } else if args.is_empty() {
2436 validator_invocations.push(quote! {
2437 {
2438 #( #case_setup )*
2439 let _ = sqlx::query_as!(
2440 __SqlForgeModel,
2441 #sql_lit,
2442 );
2443 }
2444 });
2445 } else {
2446 validator_invocations.push(quote! {
2447 {
2448 #( #case_setup )*
2449 let _ = sqlx::query_as!(
2450 __SqlForgeModel,
2451 #sql_lit,
2452 #( #args ),*
2453 );
2454 }
2455 });
2456 }
2457 }
2458
2459 let model_alias = if let Some(model) = model_opt {
2460 if scalar_model_ty.is_none() {
2461 quote! { type __SqlForgeModel = #model; }
2462 } else {
2463 quote! {}
2464 }
2465 } else {
2466 quote! {}
2467 };
2468 grouped_validator_invocations.push(quote! {
2469 {
2470 #( #flag_bindings )*
2471 #model_alias
2472 #( #validator_invocations )*
2473 }
2474 });
2475
2476 let (runtime_declared_params, runtime_param_bindings) =
2477 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2478 Ok(v) => v,
2479 Err(err) => return err,
2480 };
2481
2482 let mut runtime_steps = Vec::<TokenStream2>::new();
2483 for (seg_idx, segment) in segments.iter().enumerate() {
2484 match segment {
2485 Segment::Text(text) => {
2486 for part in parse_text_parts(text) {
2487 match part {
2488 TextPart::Lit(lit) => {
2489 let lit = sanitize_runtime_sql_text(&lit);
2490 let lit_str = LitStr::new(&lit, sql_span);
2491 runtime_steps.push(quote! {
2492 __builder.push(#lit_str);
2493 });
2494 }
2495 TextPart::Param { name, is_list } => {
2496 let Some(local_ident) = runtime_declared_params.get(&name) else {
2497 return syn::Error::new(
2498 sql_span,
2499 format!("sql_forge!: parameter :{} has no mapping", name),
2500 )
2501 .to_compile_error()
2502 .into();
2503 };
2504
2505 if is_list {
2506 runtime_steps.push(quote! {
2507 let __sql_forge_values = #local_ident;
2508 let mut __separated = __builder.separated(", ");
2509 for __value in __sql_forge_values {
2510 __separated.push_bind(__value);
2511 }
2512 });
2513 } else {
2514 runtime_steps.push(quote! {
2515 __builder.push_bind(#local_ident);
2516 });
2517 }
2518 }
2519 }
2520 }
2521 }
2522 Segment::Section { name } => {
2523 let Some(section_action) = runtime_section_actions.get(name) else {
2524 let _ = seg_idx;
2525 return syn::Error::new(
2526 sql_span,
2527 format!("sql_forge!: section {{#{}}} has no mapping", name),
2528 )
2529 .to_compile_error()
2530 .into();
2531 };
2532 runtime_steps.push(quote! {
2533 #section_action
2534 });
2535 }
2536 Segment::Batch { parts } => {
2537 if let Some(batch_expr) = &batch {
2538 let mut body = Vec::<TokenStream2>::new();
2539 for part in parts {
2540 match part {
2541 TextPart::Lit(lit) => {
2542 let lit_str = LitStr::new(lit, sql_span);
2543 body.push(quote! {
2544 __builder.push(#lit_str);
2545 });
2546 }
2547 TextPart::Param { name, .. } => {
2548 let field_ident = format_ident!("{}", name);
2549 body.push(quote! {
2550 __builder.push_bind(__item.#field_ident);
2551 });
2552 }
2553 }
2554 }
2555 runtime_steps.push(quote! {
2556 {
2557 let mut __first = true;
2558 for __item in #batch_expr {
2559 if !__first {
2560 __builder.push(", ");
2561 }
2562 __first = false;
2563 #( #body )*
2564 }
2565 }
2566 });
2567 }
2568 }
2569 }
2570 }
2571
2572 let exec_methods = if model_opt.is_none() {
2573 quote! {
2574 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2575 where
2576 E: sqlx::Executor<'e, Database = #db>,
2577 {
2578 self.inner.build().execute(executor).await
2579 }
2580 }
2581 } else if let Some(scalar_ty) = scalar_model_ty {
2582 quote! {
2583 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2584 where
2585 E: sqlx::Executor<'e, Database = #db>,
2586 {
2587 self.inner
2588 .build_query_scalar::<#scalar_ty>()
2589 .fetch_all(executor)
2590 .await
2591 }
2592
2593 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2594 where
2595 E: sqlx::Executor<'e, Database = #db>,
2596 {
2597 self.inner
2598 .build_query_scalar::<#scalar_ty>()
2599 .fetch_one(executor)
2600 .await
2601 }
2602
2603 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2604 where
2605 E: sqlx::Executor<'e, Database = #db>,
2606 {
2607 self.inner
2608 .build_query_scalar::<#scalar_ty>()
2609 .fetch_optional(executor)
2610 .await
2611 }
2612
2613 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2614 where
2615 E: sqlx::Executor<'e, Database = #db>,
2616 {
2617 self.inner.build().execute(executor).await
2618 }
2619 }
2620 } else {
2621 let model = model_opt.as_ref().unwrap();
2622 quote! {
2623 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2624 where
2625 E: sqlx::Executor<'e, Database = #db>,
2626 {
2627 self.inner.build_query_as::<#model>().fetch_all(executor).await
2628 }
2629
2630 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2631 where
2632 E: sqlx::Executor<'e, Database = #db>,
2633 {
2634 self.inner.build_query_as::<#model>().fetch_one(executor).await
2635 }
2636
2637 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2638 where
2639 E: sqlx::Executor<'e, Database = #db>,
2640 {
2641 self.inner
2642 .build_query_as::<#model>()
2643 .fetch_optional(executor)
2644 .await
2645 }
2646
2647 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2648 where
2649 E: sqlx::Executor<'e, Database = #db>,
2650 {
2651 self.inner.build().execute(executor).await
2652 }
2653 }
2654 };
2655
2656 let final_type: TokenStream2 = if let Some(model) = model_opt {
2657 if let Some(scalar_ty) = scalar_model_ty {
2658 quote! { #scalar_ty }
2659 } else {
2660 quote! { #model }
2661 }
2662 } else {
2663 quote! {}
2664 };
2665 let trait_impl = if model_opt.is_none() {
2666 quote! {
2667 impl sql_forge::SqlForgeQueryExecute
2668 for #query_ident
2669 {
2670 type Db = #db;
2671
2672 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2673 where
2674 Self: Sized + 'e,
2675 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2676 #db: 'e,
2677 {
2678 #query_ident::execute(self, executor)
2679 }
2680 }
2681 }
2682 } else {
2683 quote! {
2684 impl sql_forge::SqlForgeQuery<#final_type>
2685 for #query_ident
2686 {
2687 type Db = #db;
2688
2689 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2690 where
2691 Self: Sized + 'e,
2692 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2693 #db: 'e,
2694 {
2695 #query_ident::fetch_all(self, executor)
2696 }
2697
2698 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2699 where
2700 Self: Sized + 'e,
2701 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2702 #db: 'e,
2703 {
2704 #query_ident::fetch_one(self, executor)
2705 }
2706
2707 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2708 where
2709 Self: Sized + 'e,
2710 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2711 #db: 'e,
2712 {
2713 #query_ident::fetch_optional(self, executor)
2714 }
2715
2716 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2717 where
2718 Self: Sized + 'e,
2719 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2720 #db: 'e,
2721 {
2722 #query_ident::execute(self, executor)
2723 }
2724 }
2725 }
2726 };
2727
2728 generated_query_defs.push(quote! {
2729 struct #query_ident {
2730 inner: sqlx::QueryBuilder<#db>,
2731 }
2732
2733 impl #query_ident {
2734 #exec_methods
2735 }
2736
2737 #trait_impl
2738 });
2739
2740 generated_query_values.push(quote! {
2741 #( #runtime_param_bindings )*
2742 #( #flag_bindings )*
2743 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2744 #( #runtime_steps )*
2745 let #query_value_ident = #query_ident { inner: __builder };
2746 });
2747
2748 if let Some(key) = result_key {
2749 let method_ident = format_ident!("{}", key);
2750 group_field_defs.push(quote! {
2751 #method_ident: #query_ident
2752 });
2753 group_field_tys.push(quote! { #query_ident });
2754
2755 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2756 group_trait_impls.push(quote! {
2757 struct #key_ty_ident;
2758
2759 impl sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup {
2760 type Query = #query_ident;
2761
2762 fn get(self, _: #key_ty_ident) -> Self::Query {
2763 self.#method_ident
2764 }
2765 }
2766 });
2767 group_field_idents.push(method_ident);
2768 }
2769 }
2770
2771 let validator_tokens = quote! {
2773 let _sql_forge_validator = || {
2774 #( #validator_param_bindings )*
2775 #( #grouped_validator_invocations )*
2776 };
2777 };
2778
2779 if !is_grouped_result {
2780 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2781 return quote! {
2782 {
2783 #validator_tokens
2784 #( #generated_query_defs )*
2785 #( #generated_query_values )*
2786 #single_query_value_ident
2787 }
2788 }
2789 .into();
2790 }
2791
2792 let group_field_inits: Vec<TokenStream2> = result_cases
2793 .iter()
2794 .filter_map(|(key, _, _)| key.as_ref())
2795 .map(|key| {
2796 let method_ident = format_ident!("{}", key);
2797 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2798 quote! { #method_ident: #query_value_ident }
2799 })
2800 .collect();
2801
2802 quote! {
2803 {
2804 #validator_tokens
2805
2806 #( #generated_query_defs )*
2807 #( #generated_query_values )*
2808
2809 struct __SqlForgeQueryGroup {
2810 #( #group_field_defs, )*
2811 }
2812
2813 impl __SqlForgeQueryGroup {
2814 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2815 ( #( self.#group_field_idents ),* )
2816 }
2817 }
2818
2819 impl sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup {
2820 type Db = #db;
2821 }
2822
2823 #( #group_trait_impls )*
2824
2825 __SqlForgeQueryGroup {
2826 #( #group_field_inits, )*
2827 }
2828 }
2829 }
2830 .into()
2831}
2832
2833#[proc_macro]
2847pub fn db_type(input: TokenStream) -> TokenStream {
2848 if !input.is_empty() {
2849 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2850 .to_compile_error()
2851 .into();
2852 }
2853
2854 match resolve_db_from_env() {
2855 Ok(db) => quote! { #db }.into(),
2856 Err(msg) => syn::Error::new(Span::call_site(), msg)
2857 .to_compile_error()
2858 .into(),
2859 }
2860}
2861
2862#[proc_macro_attribute]
2877pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2878 let input: ItemStruct = match syn::parse(item) {
2879 Ok(v) => v,
2880 Err(err) => return err.to_compile_error().into(),
2881 };
2882
2883 let struct_name = &input.ident;
2884 let inner_type = match &input.fields {
2885 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2886 _ => {
2887 return syn::Error::new(
2888 input.span(),
2889 "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2890 )
2891 .to_compile_error()
2892 .into();
2893 }
2894 };
2895
2896 let attrs = input.attrs;
2897 let generics = &input.generics;
2898 let vis = &input.vis;
2899 let struct_token = input.struct_token;
2900 let semi_token = input.semi_token;
2901 let fields = &input.fields;
2902
2903 let expanded = quote! {
2904 #( #attrs )*
2905 #[derive(sqlx::Type)]
2906 #[sqlx(transparent)]
2907 #vis #struct_token #struct_name #generics #fields #semi_token
2908
2909 impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2910 fn sql_forge_validator_value(&self) -> #inner_type {
2911 self.0.clone()
2912 }
2913 }
2914 };
2915
2916 expanded.into()
2917}