1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Ident, Lit, LitStr, Pat, Stmt, Token, Type,
17};
18
/// Custom keywords recognised by the `sql_forge!` argument parser.
mod kw {
    // `scalar` prefixes a result type (e.g. `scalar i64`); the parser records it
    // as `force_scalar = true`.
    syn::custom_keyword!(scalar);
}
26
27#[derive(Clone)]
29struct ParamAssign {
30 name: Ident,
31 expr: Expr,
32}
33
34impl Parse for ParamAssign {
35 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
36 input.parse::<Token![:]>()?;
37 let name: Ident = input.parse()?;
38 input.parse::<Token![=]>()?;
39 let expr: Expr = input.parse()?;
40 Ok(Self { name, expr })
41 }
42}
43
/// A single SQL fragment supplied for a section placeholder, together with any
/// parameters that are local to that fragment.
#[derive(Clone)]
struct SectionFragment {
    /// Fragment SQL text (the value of the string literal).
    sql: String,
    /// Span of the literal expression, used for diagnostics.
    span: Span,
    /// Section-local parameter source; `ParamsSource::None` for a bare literal.
    params: ParamsSource,
}

/// One arm of a `match` used as a section value.
#[derive(Clone)]
struct SectionMatchArm {
    pat: Pat,
    /// Optional `if` guard of the arm.
    guard: Option<Expr>,
    /// Value produced by this arm (may itself be grouped or another match).
    value: SectionValue,
}

/// Right-hand side of a `#name = ...` / `#(a, b) = ...` section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single fragment: `"sql"` or `("sql", params)`.
    Single(SectionFragment),
    /// One value per key of a grouped assignment: `#(a, b) = (.., ..)`.
    Grouped(Vec<SectionValue>),
    /// `match expr { pat => value, ... }`.
    Match {
        /// The scrutinee expression.
        expr: Expr,
        arms: Vec<SectionMatchArm>,
    },
}

/// One entry of a section map: `#name = value` or `#(a, b, ...) = value`.
#[derive(Clone)]
struct SectionAssign {
    /// Section names assigned by this entry (several for a grouped assignment).
    names: Vec<Ident>,
    value: SectionValue,
}
80
81impl Parse for SectionAssign {
82 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
83 input.parse::<Token![#]>()?;
84
85 let names = if input.peek(syn::token::Paren) {
87 let content;
88 syn::parenthesized!(content in input);
89 let mut out = Vec::new();
90 while !content.is_empty() {
91 out.push(content.parse::<Ident>()?);
92 if content.is_empty() {
93 break;
94 }
95 content.parse::<Token![,]>()?;
96 }
97 if out.is_empty() {
98 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
99 }
100 out
101 } else {
102 vec![input.parse::<Ident>()?]
103 };
104
105 input.parse::<Token![=]>()?;
106 let value = parse_section_value(input, names.len())?;
107 Ok(Self { names, value })
108 }
109}
110
/// Fully parsed arguments of one `sql_forge!` invocation.
struct SqlForgeInput {
    /// Explicit database type, when one was passed before the result argument.
    db: Option<Type>,
    /// Result shape requested by the caller.
    result: ResultSpec,
    /// Set when the single result type was prefixed with `scalar`.
    force_scalar: bool,
    /// The SQL template argument.
    sql: SqlTemplate,
    /// Top-level parameter source.
    params: ParamsSource,
    /// Section map entries (`#name = ...`).
    sections: Vec<SectionAssign>,
    /// Batch source expression, given as `..expr`.
    batch: Option<Expr>,
}

/// One `>name = Model` (or `>name = scalar Model`) entry of a result map.
#[derive(Clone)]
struct ResultAssign {
    name: Ident,
    model: Type,
    /// Set when the model type was prefixed with `scalar`.
    force_scalar: bool,
}

/// Result argument of the macro.
#[derive(Clone)]
enum ResultSpec {
    /// No result type: the invocation starts directly with the SQL literal.
    None,
    /// A single model type.
    Single(Box<Type>),
    /// A result map `(>a = A, >b = B, ...)`.
    Group(Vec<ResultAssign>),
}

/// Where parameter values come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameters supplied.
    None,
    /// Explicit `(:name = expr, ...)` map.
    Map(Vec<ParamAssign>),
    /// Any other expression; parameters are read as same-named fields of its value.
    Struct(Box<Expr>),
}

/// The SQL template argument.
enum SqlTemplate {
    /// A string-literal template (currently the only accepted form).
    Literal(LitStr),
}
154impl SqlTemplate {
155 fn span(&self) -> Span {
156 match self {
157 Self::Literal(lit) => lit.span(),
158 }
159 }
160
161 fn into_segments(self) -> Result<Vec<Segment>, String> {
163 match self {
164 Self::Literal(lit) => parse_literal_segments(&lit.value()),
165 }
166 }
167}
168
169fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
170 if input.peek(LitStr) {
171 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
172 } else {
173 Err(input.error("sql_forge!: SQL template must be a string literal"))
174 }
175}
176
/// A piece of a parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Literal SQL text (adjacent text is merged by `push_text_segment`).
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, pre-split into text and parameter parts.
    Batch { parts: Vec<TextPart> },
}

/// A piece of SQL text after splitting out `:name` parameter references.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:name` reference; `is_list` is set for the `:name[]` form.
    Param { name: String, is_list: bool },
}

/// Kind of a parenthesized map argument, decided by its first token.
enum MapKind {
    /// `(>name = Model, ...)` result map.
    Results,
    /// `(:name = expr, ...)` parameter map.
    Params,
    /// `(#name = value, ...)` section map.
    Sections,
}
207
208fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
213 let fork = input.fork();
214 let content;
215 syn::parenthesized!(content in fork);
216
217 if content.is_empty() {
218 return Err(input.error("sql_forge!: map argument cannot be empty"));
219 }
220
221 if content.peek(Token![>]) {
222 Ok(Some(MapKind::Results))
223 } else if content.peek(Token![:]) {
224 Ok(Some(MapKind::Params))
225 } else if content.peek(Token![#]) {
226 Ok(Some(MapKind::Sections))
227 } else {
228 Ok(None)
229 }
230}
231
232impl Parse for ResultAssign {
233 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
234 input.parse::<Token![>]>()?;
235 let name: Ident = input.parse()?;
236 input.parse::<Token![=]>()?;
237 let (force_scalar, model) = if input.peek(kw::scalar) {
238 input.parse::<kw::scalar>()?;
239 (true, input.parse::<Type>()?)
240 } else {
241 (false, input.parse::<Type>()?)
242 };
243 Ok(Self {
244 name,
245 model,
246 force_scalar,
247 })
248 }
249}
250
251fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
252 let content;
253 syn::parenthesized!(content in input);
254
255 let mut results = Vec::new();
256 while !content.is_empty() {
257 results.push(content.parse::<ResultAssign>()?);
258 if content.is_empty() {
259 break;
260 }
261 content.parse::<Token![,]>()?;
262 }
263
264 if results.is_empty() {
265 return Err(input.error("sql_forge!: result map cannot be empty"));
266 }
267
268 Ok(results)
269}
270
271fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
272 let content;
273 syn::parenthesized!(content in input);
274
275 let mut params = Vec::new();
276 while !content.is_empty() {
277 params.push(content.parse::<ParamAssign>()?);
278 if content.is_empty() {
279 break;
280 }
281 content.parse::<Token![,]>()?;
282 }
283
284 Ok(params)
285}
286
287fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
288 let content;
289 syn::parenthesized!(content in input);
290
291 let mut sections = Vec::new();
292 while !content.is_empty() {
293 sections.push(content.parse::<SectionAssign>()?);
294 if content.is_empty() {
295 break;
296 }
297 content.parse::<Token![,]>()?;
298 }
299
300 Ok(sections)
301}
302
303fn parse_params_source_expr(
304 input: ParseStream<'_>,
305 allow_sections: bool,
306) -> syn::Result<ParamsSource> {
307 if input.peek(syn::token::Paren) {
308 match detect_parenthesized_map_kind(input)? {
309 Some(MapKind::Results) => Err(input
310 .error("sql_forge!: result maps are only allowed as the macro result argument")),
311 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
312 Some(MapKind::Sections) if allow_sections => {
313 Err(input.error("sql_forge!: section maps are not allowed here"))
314 }
315 Some(MapKind::Sections) => Err(input.error(
316 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317 )),
318 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319 }
320 } else {
321 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322 }
323}
324
325fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
326 if input.peek(syn::token::Paren) {
327 let fork = input.fork();
328 let content;
329 syn::parenthesized!(content in fork);
330
331 if let Ok(first_expr) = content.parse::<Expr>() {
332 if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
333 let _ = parse_params_source_expr(&content, false)?;
334 if content.peek(Token![,]) {
335 content.parse::<Token![,]>()?;
336 }
337 if content.is_empty() {
338 let content;
339 syn::parenthesized!(content in input);
340 let first_expr: Expr = content.parse()?;
341 let sql = extract_lit_str(&first_expr).ok_or_else(|| {
342 input.error("sql_forge!: section tuple must start with a string literal")
343 })?;
344 let span = first_expr.span();
345 content.parse::<Token![,]>()?;
346 let params = parse_params_source_expr(&content, false)?;
347 if content.peek(Token![,]) {
348 content.parse::<Token![,]>()?;
349 }
350 if !content.is_empty() {
351 return Err(content.error(
352 "sql_forge!: unexpected tokens after section-local parameter source",
353 ));
354 }
355 return Ok(SectionFragment { sql, span, params });
356 }
357 }
358 }
359 }
360
361 let expr: Expr = input.parse()?;
362 let sql = extract_lit_str(&expr).ok_or_else(|| {
363 input
364 .error("sql_forge!: section values must be string literals or (string literal, params)")
365 })?;
366 Ok(SectionFragment {
367 sql,
368 span: expr.span(),
369 params: ParamsSource::None,
370 })
371}
372
373fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
374 if input.peek(Token![match]) {
375 input.parse::<Token![match]>()?;
376 let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
377 let content;
378 syn::braced!(content in input);
379 let mut arms = Vec::new();
380 while !content.is_empty() {
381 let pat = content.call(Pat::parse_multi_with_leading_vert)?;
382 let guard = if content.peek(Token![if]) {
383 content.parse::<Token![if]>()?;
384 Some(content.parse::<Expr>()?)
385 } else {
386 None
387 };
388 content.parse::<Token![=>]>()?;
389 let value = parse_section_value(&content, width)?;
390 if content.peek(Token![,]) {
391 content.parse::<Token![,]>()?;
392 }
393 arms.push(SectionMatchArm { pat, guard, value });
394 }
395 return Ok(SectionValue::Match { expr, arms });
396 }
397
398 if width == 1 {
399 return Ok(SectionValue::Single(parse_section_fragment(input)?));
400 }
401
402 let content;
403 syn::parenthesized!(content in input);
404 let mut items = Vec::new();
405 while !content.is_empty() {
406 items.push(parse_section_value(&content, 1)?);
407 if content.is_empty() {
408 break;
409 }
410 content.parse::<Token![,]>()?;
411 }
412
413 if items.len() != width {
414 return Err(input.error(format!(
415 "sql_forge!: grouped section value must provide exactly {} items",
416 width,
417 )));
418 }
419
420 Ok(SectionValue::Grouped(items))
421}
422
423impl Parse for SqlForgeInput {
428 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
429 let (db, result, force_scalar, sql) = if input.peek(LitStr) {
430 let sql = parse_sql_template(input)?;
431 (None, ResultSpec::None, false, sql)
432 } else if input.peek(kw::scalar) {
433 input.parse::<kw::scalar>()?;
434 let model: Type = input.parse()?;
435 input.parse::<Token![,]>()?;
436 let sql = parse_sql_template(input)?;
437 (None, ResultSpec::Single(Box::new(model)), true, sql)
438 } else if input.peek(syn::token::Paren) {
439 let result_map_kind = detect_parenthesized_map_kind(input)?;
440 match result_map_kind {
441 Some(MapKind::Results) => {
442 let result = ResultSpec::Group(parse_result_map(input)?);
443 input.parse::<Token![,]>()?;
444 let sql = parse_sql_template(input)?;
445 (None, result, false, sql)
446 }
447 _ => {
448 return Err(input.error(
449 "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
450 ));
451 }
452 }
453 } else {
454 let first_ty: Type = input.parse()?;
455 input.parse::<Token![,]>()?;
456
457 if input.peek(LitStr) {
458 let model = first_ty;
459 let sql = parse_sql_template(input)?;
460 (None, ResultSpec::Single(Box::new(model)), false, sql)
461 } else if input.peek(kw::scalar) {
462 input.parse::<kw::scalar>()?;
463 let model: Type = input.parse()?;
464 input.parse::<Token![,]>()?;
465 let sql = parse_sql_template(input)?;
466 (
467 Some(first_ty),
468 ResultSpec::Single(Box::new(model)),
469 true,
470 sql,
471 )
472 } else if input.peek(syn::token::Paren)
473 && matches!(
474 detect_parenthesized_map_kind(input)?,
475 Some(MapKind::Results)
476 )
477 {
478 let result = ResultSpec::Group(parse_result_map(input)?);
479 input.parse::<Token![,]>()?;
480 let sql = parse_sql_template(input)?;
481 (Some(first_ty), result, false, sql)
482 } else {
483 let db = Some(first_ty);
484 let model: Type = input.parse()?;
485 input.parse::<Token![,]>()?;
486 let sql = parse_sql_template(input)?;
487 (db, ResultSpec::Single(Box::new(model)), false, sql)
488 }
489 };
490
491 let mut batch = None;
492 let mut params = ParamsSource::None;
493 let mut sections = Vec::new();
494 let mut seen_params = false;
495 let mut seen_sections = false;
496
497 if input.parse::<Token![,]>().is_ok() {
498 while !input.is_empty() {
499 if input.peek(Token![..]) {
500 if batch.is_some() {
501 return Err(
502 input.error("sql_forge!: only one batch source argument is allowed")
503 );
504 }
505 input.parse::<Token![..]>()?;
506 batch = Some(input.parse::<Expr>()?);
507 } else if input.peek(syn::token::Paren) {
508 match detect_parenthesized_map_kind(input)? {
509 Some(MapKind::Results) => {
510 return Err(input.error(
511 "sql_forge!: result maps are only allowed as the macro result argument",
512 ));
513 }
514 Some(MapKind::Params) => {
515 if seen_params {
516 return Err(
517 input.error("sql_forge!: only one parameter source is allowed")
518 );
519 }
520 params = ParamsSource::Map(parse_param_map(input)?);
521 seen_params = true;
522 }
523 Some(MapKind::Sections) => {
524 if seen_sections {
525 return Err(
526 input.error("sql_forge!: duplicate section map argument")
527 );
528 }
529 sections = parse_section_map(input)?;
530 seen_sections = true;
531 }
532 None => {
533 if seen_params {
534 return Err(
535 input.error("sql_forge!: only one parameter source is allowed")
536 );
537 }
538 params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
539 seen_params = true;
540 }
541 }
542 } else {
543 if seen_params {
544 return Err(input.error("sql_forge!: only one parameter source is allowed"));
545 }
546 params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
547 seen_params = true;
548 }
549
550 if input.parse::<Token![,]>().is_ok() {
551 continue;
552 }
553 break;
554 }
555 }
556
557 if !input.is_empty() {
558 return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
559 }
560
561 Ok(Self {
562 db,
563 result,
564 force_scalar,
565 sql,
566 params,
567 sections,
568 batch,
569 })
570 }
571}
572
573fn resolve_db_from_env() -> Result<Type, String> {
578 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
579 return syn::parse_str::<Type>(&val).map_err(|err| {
580 format!(
581 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
582 val, err
583 )
584 });
585 }
586
587 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
588 Ok(d) => d,
589 Err(_) => {
590 return Err(
591 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
592 or configure [package.metadata.sql_forge] in Cargo.toml"
593 .to_string(),
594 );
595 }
596 };
597 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
598
599 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
600 format!(
601 "sql_forge!: failed to read {}: {}",
602 manifest_path.display(),
603 err
604 )
605 })?;
606
607 let value: toml::Value = toml::from_str(&cargo_toml)
608 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
609
610 let db_str = value
611 .get("package")
612 .and_then(|v| v.get("metadata"))
613 .and_then(|v| v.get("sql_forge"))
614 .and_then(|v| v.get("db"))
615 .and_then(|v| v.as_str())
616 .ok_or({
617 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
618 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
619 })?;
620
621 syn::parse_str::<Type>(db_str).map_err(|err| {
622 format!(
623 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
624 db_str, err
625 )
626 })
627}
628
629fn uses_dollar_params(db: &Type) -> bool {
630 let Type::Path(type_path) = db else {
631 return false;
632 };
633 type_path
634 .path
635 .segments
636 .last()
637 .is_some_and(|s| s.ident == "Postgres")
638}
639
640fn is_builtin_scalar_type(ty: &Type) -> bool {
641 let Type::Path(type_path) = ty else {
642 return false;
643 };
644
645 if type_path.qself.is_some()
646 || type_path.path.leading_colon.is_some()
647 || type_path.path.segments.len() != 1
648 {
649 return false;
650 }
651
652 let ident = &type_path.path.segments[0].ident;
653 ident == "i8"
654 || ident == "i16"
655 || ident == "i32"
656 || ident == "i64"
657 || ident == "isize"
658 || ident == "u8"
659 || ident == "u16"
660 || ident == "u32"
661 || ident == "u64"
662 || ident == "usize"
663 || ident == "f32"
664 || ident == "f64"
665 || ident == "bool"
666 || ident == "String"
667}
668
669fn scalar_output_type(model: &Type) -> Option<&Type> {
670 if is_builtin_scalar_type(model) {
671 return Some(model);
672 }
673 None
674}
675
676fn push_text_segment(out: &mut Vec<Segment>, text: String) {
677 if text.is_empty() {
678 return;
679 }
680 match out.last_mut() {
681 Some(Segment::Text(existing)) => existing.push_str(&text),
682 _ => out.push(Segment::Text(text)),
683 }
684}
685
/// Splits a literal SQL template into segments.
///
/// Recognised placeholders:
/// - `{#name}` is a section placeholder, emitted as `Segment::Section`.
/// - `{( ... )}` is a batch section, emitted as `Segment::Batch`. Its contents,
///   including the outer parentheses, are split by `parse_text_parts`.
///
/// Any other `{` is kept as literal text, and adjacent text runs are merged.
///
/// # Errors
/// Returns a message when:
/// - a placeholder is not closed;
/// - a batch section contains `{` or has unbalanced parentheses;
/// - a batch section uses a list parameter (`:name[]`);
/// - a section name is empty.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    // Pending literal text; flushed whenever a placeholder starts.
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        // `{(` opens a batch section. The `(` is only peeked here, so it becomes
        // the first collected character below and counts toward `paren_depth`.
        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                // The closing `}` is only accepted once every `(` has been matched.
                if ch == '}' {
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            // List parameters cannot be expanded inside a per-row batch section.
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` that does not start `{#` or `{(` is ordinary text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Consume the `#`, then read the section name up to the next `}` verbatim
        // (no trimming or validation beyond non-emptiness).
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
778
/// Whether `ch` may begin a `:name` parameter name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
786
/// Whether `ch` may continue a `:name` parameter name: an ASCII letter, an
/// ASCII digit, or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch == '_' || ch.is_ascii_alphanumeric()
}
790
/// Strips a type/nullability override suffix from backticked alias text.
///
/// The text is cut at the first `!`, `?` or `:`, and trailing whitespace is
/// trimmed (so `id!: i64` becomes `id`). The text is returned unchanged when it
/// contains none of those characters, or when nothing would remain before the cut.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(cut) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };

    match content[..cut].trim_end() {
        "" => content.to_string(),
        base => base.to_string(),
    }
}
811
812fn sanitize_runtime_sql_text(text: &str) -> String {
813 let mut out = String::with_capacity(text.len());
814 let mut chars = text.chars().peekable();
815
816 while let Some(ch) = chars.next() {
817 if ch != '`' {
818 out.push(ch);
819 continue;
820 }
821
822 let mut content = String::new();
823 let mut closed = false;
824
825 for next in chars.by_ref() {
826 if next == '`' {
827 closed = true;
828 break;
829 }
830 content.push(next);
831 }
832
833 if closed {
834 out.push('`');
835 out.push_str(&sanitize_backticked_alias_ident(&content));
836 out.push('`');
837 } else {
838 out.push('`');
839 out.push_str(&content);
840 break;
841 }
842 }
843
844 out
845}
846
/// Splits `text` into literal runs and `:name` / `:name[]` parameter references.
///
/// A `:` starts a parameter only when both of these hold:
/// - it is immediately followed by an ASCII letter or `_`;
/// - it is not itself preceded by `:`, so `x::int` casts are left alone.
///
/// A name continues over ASCII letters, digits and `_`, and a directly
/// following `[]` marks a list parameter.
/// NOTE(review): the scan is purely lexical; a `:` inside a quoted SQL string is
/// not treated specially.
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset just past the last emitted part.
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // Second colon of a `::` cast: not a parameter.
        if text[..idx].ends_with(':') {
            continue;
        }

        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        // Consume the already-peeked first name character.
        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        // Byte offset just past the name (and past `[]`, once checked).
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        let mut is_list = false;
        // `[]` is not consumed from `iter`. Its characters are skipped by the
        // `ch != ':'` check on later iterations and are excluded from the next
        // literal via `last`.
        if text[end..].starts_with("[]") {
            is_list = true;
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
905
906fn render_validator_text(
907 text: &str,
908 use_dollar_params: bool,
909 param_offset: &mut usize,
910 list_count: usize,
911) -> (String, Vec<(String, bool)>) {
912 let mut out_sql = String::new();
913 let mut occurrences = Vec::new();
914
915 for part in parse_text_parts(text) {
916 match part {
917 TextPart::Lit(lit) => out_sql.push_str(&lit),
918 TextPart::Param { name, is_list } => {
919 if is_list && list_count > 1 {
920 let slots: Vec<String> = if use_dollar_params {
921 (0..list_count)
922 .map(|i| format!("${}", *param_offset + i + 1))
923 .collect()
924 } else {
925 (0..list_count).map(|_| "?".to_string()).collect()
926 };
927 if use_dollar_params {
928 *param_offset += list_count;
929 }
930 out_sql.push_str(&slots.join(", "));
931 } else if use_dollar_params {
932 *param_offset += 1;
933 write!(out_sql, "${}", *param_offset).unwrap();
934 } else {
935 out_sql.push('?');
936 }
937 occurrences.push((name, is_list));
938 }
939 }
940 }
941
942 (out_sql, occurrences)
943}
944
945fn strip_expr(expr: &Expr) -> &Expr {
946 match expr {
947 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
948 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
949 Expr::Block(ExprBlock { block, .. }) => {
950 if block.stmts.len() != 1 {
951 return expr;
952 }
953 match &block.stmts[0] {
954 Stmt::Expr(inner, None) => strip_expr(inner),
955 _ => expr,
956 }
957 }
958 _ => expr,
959 }
960}
961
962fn extract_lit_str(expr: &Expr) -> Option<String> {
963 match strip_expr(expr) {
964 Expr::Lit(ExprLit {
965 lit: Lit::Str(lit), ..
966 }) => Some(lit.value()),
967 _ => None,
968 }
969}
970
971fn result_flag_ident(name: &str) -> syn::Ident {
976 format_ident!("__enhanced_result_flag_{}", name)
977}
978
979fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
983 fn walk(stream: TokenStream2) -> TokenStream2 {
984 let mut out = TokenStream2::new();
985 let iter = stream.into_iter().peekable();
986
987 for token in iter {
988 match token {
989 TokenTree::Group(group) => {
990 if group.delimiter() == Delimiter::Brace {
991 let mut inner = group.stream().into_iter();
992 let first = inner.next();
993 let second = inner.next();
994 let third = inner.next();
995
996 if let (
997 Some(TokenTree::Punct(p)),
998 Some(TokenTree::Ident(name_ident)),
999 None,
1000 ) = (first, second, third)
1001 {
1002 if p.as_char() == '>' {
1003 let ident = result_flag_ident(&name_ident.to_string());
1004 out.extend(std::iter::once(TokenTree::Ident(ident)));
1005 continue;
1006 }
1007 }
1008 }
1009
1010 let new_inner = walk(group.stream());
1011 let mut new_group = Group::new(group.delimiter(), new_inner);
1012 new_group.set_span(group.span());
1013 out.extend(std::iter::once(TokenTree::Group(new_group)));
1014 }
1015 other => out.extend(std::iter::once(other)),
1016 }
1017 }
1018
1019 out
1020 }
1021
1022 walk(input)
1023}
1024
1025fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1026 keys.iter()
1027 .map(|key| {
1028 let ident = result_flag_ident(key);
1029 let enabled = Some(key.as_str()) == active_key;
1030 quote! { let #ident: bool = #enabled; }
1031 })
1032 .collect()
1033}
1034
1035fn transpose_section_case_matrix(
1036 case_matrix: Vec<Vec<SectionFragment>>,
1037 width: usize,
1038) -> Result<Vec<Vec<SectionFragment>>, String> {
1039 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1040
1041 for row in case_matrix {
1042 if row.len() != width {
1043 return Err(
1044 "sql_forge!: grouped sections must return one item per section".to_string(),
1045 );
1046 }
1047 for (section_idx, fragment) in row.into_iter().enumerate() {
1048 per_section[section_idx].push(fragment);
1049 }
1050 }
1051
1052 Ok(per_section)
1053}
1054
/// Expands a section value into a case matrix.
///
/// Each row holds one fragment per section (`width` columns) for one
/// combination the validator should check.
///
/// - `Single` yields one row and is only valid when `width == 1`.
/// - `Grouped` expands each item separately, then zips their variants
///   cyclically: row `i` takes variant `i % len` of every item. Every variant
///   therefore appears in at least one row, but this is not a full cartesian
///   product.
/// - `Match` concatenates the rows of its arms. Each fragment's parameter
///   expressions are wrapped in that arm's pattern/guard.
///   - When the scrutinee is a result-flag ident (see `expr_result_flag_key`),
///     unguarded arms whose pattern cannot match the flag value implied by
///     `active_key` are skipped.
///
/// # Errors
/// Returns a message on width mismatches, or when an item or match yields no rows.
fn collect_section_case_matrix(
    value: SectionValue,
    width: usize,
    active_key: Option<&str>,
) -> Result<Vec<Vec<SectionFragment>>, String> {
    match value {
        SectionValue::Single(fragment) => {
            if width != 1 {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }
            Ok(vec![vec![fragment]])
        }
        SectionValue::Grouped(values) => {
            if values.len() != width {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }

            // For each grouped item, the list of fragments it can produce.
            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
            let mut nmax = 1usize;

            for value in values {
                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
                for mut row in item_matrix {
                    let fragment = row.pop().ok_or_else(|| {
                        "sql_forge!: grouped sections must return one item per section".to_string()
                    })?;
                    if !row.is_empty() {
                        return Err(
                            "sql_forge!: grouped sections must return one item per section"
                                .to_string(),
                        );
                    }
                    item_variants.push(fragment);
                }
                if item_variants.is_empty() {
                    return Err("sql_forge!: section match must have at least one arm".to_string());
                }
                nmax = nmax.max(item_variants.len());
                variants_by_section.push(item_variants);
            }

            // Emit `nmax` rows, cycling through each item's variants.
            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
            for case_idx in 0..nmax {
                let mut row = Vec::<SectionFragment>::with_capacity(width);
                for variants in &variants_by_section {
                    row.push(variants[case_idx % variants.len()].clone());
                }
                case_matrix.push(row);
            }

            Ok(case_matrix)
        }
        SectionValue::Match { expr, arms } => {
            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();

            if let Some(key) = expr_result_flag_key(&expr) {
                // Matching on a result flag: its value is known statically for this
                // `active_key`, so unguarded arms that cannot match are pruned.
                let target = active_key == Some(key.as_str());
                for arm in arms {
                    if arm.guard.is_none() {
                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
                            continue;
                        }
                    }
                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
                    wrap_section_case_matrix_for_match_arm(
                        &mut arm_cases,
                        &expr,
                        &arm.pat,
                        arm.guard.as_ref(),
                    );
                    case_matrix.extend(arm_cases);
                }
            } else {
                for arm in arms {
                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
                    wrap_section_case_matrix_for_match_arm(
                        &mut arm_cases,
                        &expr,
                        &arm.pat,
                        arm.guard.as_ref(),
                    );
                    case_matrix.extend(arm_cases);
                }
            }

            if case_matrix.is_empty() {
                return Err("sql_forge!: section match must have at least one arm".to_string());
            }

            Ok(case_matrix)
        }
    }
}
1153
1154fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
1157 let match_expr = match_expr.clone();
1158 let pat = pat.clone();
1159 let pattern_binds_values = match &pat {
1160 Pat::Ident(_) => true,
1161 Pat::Or(pat_or) => pat_or
1162 .cases
1163 .iter()
1164 .any(|case| matches!(case, Pat::Ident(_))),
1165 Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
1166 Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
1167 Pat::Slice(pat_slice) => pat_slice
1168 .elems
1169 .iter()
1170 .any(|elem| matches!(elem, Pat::Ident(_))),
1171 Pat::Struct(pat_struct) => pat_struct
1172 .fields
1173 .iter()
1174 .any(|field| matches!(*field.pat, Pat::Ident(_))),
1175 Pat::Tuple(pat_tuple) => pat_tuple
1176 .elems
1177 .iter()
1178 .any(|elem| matches!(elem, Pat::Ident(_))),
1179 Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
1180 .elems
1181 .iter()
1182 .any(|elem| matches!(elem, Pat::Ident(_))),
1183 Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
1184 _ => false,
1185 };
1186
1187 if pattern_binds_values {
1188 if let Some(guard) = guard.cloned() {
1189 parse_quote! {
1190 match &(#match_expr) {
1191 #pat if #guard => { #expr },
1192 _ => unreachable!("sql_forge!: validator arm mismatch"),
1193 }
1194 }
1195 } else {
1196 parse_quote! {
1197 match &(#match_expr) {
1198 #pat => { #expr },
1199 _ => unreachable!("sql_forge!: validator arm mismatch"),
1200 }
1201 }
1202 }
1203 } else if let Some(guard) = guard.cloned() {
1204 parse_quote! {
1205 match &(#match_expr) {
1206 #pat if #guard => { &(#expr) },
1207 _ => unreachable!("sql_forge!: validator arm mismatch"),
1208 }
1209 }
1210 } else {
1211 parse_quote! {
1212 match &(#match_expr) {
1213 #pat => { &(#expr) },
1214 _ => unreachable!("sql_forge!: validator arm mismatch"),
1215 }
1216 }
1217 }
1218}
1219
1220fn wrap_params_source_for_match_arm(
1221 params: &mut ParamsSource,
1222 match_expr: &Expr,
1223 pat: &Pat,
1224 guard: Option<&Expr>,
1225) {
1226 match params {
1227 ParamsSource::None => {}
1228 ParamsSource::Map(entries) => {
1229 for entry in entries {
1230 entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1231 }
1232 }
1233 ParamsSource::Struct(expr) => {
1234 **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1235 }
1236 }
1237}
1238
1239fn wrap_section_case_matrix_for_match_arm(
1240 case_matrix: &mut [Vec<SectionFragment>],
1241 match_expr: &Expr,
1242 pat: &Pat,
1243 guard: Option<&Expr>,
1244) {
1245 for row in case_matrix {
1246 for fragment in row {
1247 wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1248 }
1249 }
1250}
1251
1252fn collect_section_variants(
1262 value: SectionValue,
1263 width: usize,
1264) -> Result<Vec<Vec<SectionFragment>>, String> {
1265 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1266}
1267
1268fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1269 match strip_expr(expr) {
1270 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1271 let name = path.path.segments[0].ident.to_string();
1272 name.strip_prefix("__enhanced_result_flag_")
1273 .map(|v| v.to_string())
1274 }
1275 _ => None,
1276 }
1277}
1278
1279fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1280 match pat {
1281 Pat::Lit(expr_lit) => match &expr_lit.lit {
1282 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1283 _ => None,
1284 },
1285 Pat::Wild(_) => Some(true),
1286 _ => None,
1287 }
1288}
1289
1290fn collect_section_variants_for_result(
1295 value: SectionValue,
1296 width: usize,
1297 active_key: Option<&str>,
1298) -> Result<Vec<Vec<SectionFragment>>, String> {
1299 transpose_section_case_matrix(
1300 collect_section_case_matrix(value, width, active_key)?,
1301 width,
1302 )
1303}
1304
/// Builds `let` bindings that evaluate the parameters used by a SQL template.
///
/// Each parameter `name` is bound to a local named `__enhanced_{prefix}_{name}`.
/// The returned map goes from parameter name to that local ident.
///
/// - `ParamsSource::Map`: one binding per entry, in declaration order. With
///   `enforce_usage_check`, an entry whose name is absent from
///   `used_param_names` is an error.
/// - `ParamsSource::Struct`: the source expression is borrowed once into
///   `__enhanced_source_{prefix}`. Then one binding per name in
///   `used_param_names` reads the same-named field.
///
/// With `for_validator`, the locals hold references (`&expr` / `&source.field`);
/// otherwise they hold the values themselves.
///
/// NOTE(review): in the non-validator `Struct` case the field is read through a
/// shared reference (`let x = source.field;`), which compiles only for `Copy`
/// fields. Confirm that callers never reach this path with owned field types.
///
/// # Errors
/// Returns a ready-to-emit `compile_error!` token stream for:
/// - a duplicated map entry;
/// - an unused map entry, when `enforce_usage_check` is set.
fn build_param_bindings(
    params: &ParamsSource,
    used_param_names: &[String],
    prefix: &str,
    for_validator: bool,
    enforce_usage_check: bool,
) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
    let mut declared_params = HashMap::<String, syn::Ident>::new();
    let mut bindings = Vec::<TokenStream2>::new();

    match params {
        ParamsSource::None => {}
        ParamsSource::Map(entries) => {
            for entry in entries {
                let key = entry.name.to_string();
                if declared_params.contains_key(&key) {
                    return Err(syn::Error::new(
                        entry.name.span(),
                        "sql_forge!: duplicated parameter mapping",
                    )
                    .to_compile_error()
                    .into());
                }
                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
                    return Err(syn::Error::new(
                        entry.name.span(),
                        format!(
                            "sql_forge!: parameter :{} is unused in the SQL template",
                            key,
                        ),
                    )
                    .to_compile_error()
                    .into());
                }
                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
                let expr = &entry.expr;
                if for_validator {
                    bindings.push(quote! {
                        let #local_ident = &(#expr);
                    });
                } else {
                    bindings.push(quote! {
                        let #local_ident = #expr;
                    });
                }
                declared_params.insert(key, local_ident);
            }
        }
        ParamsSource::Struct(expr) => {
            // Evaluate the source expression once; each field is read from this borrow.
            let source_ident = format_ident!("__enhanced_source_{}", prefix);
            bindings.push(quote! {
                let #source_ident = &(#expr);
            });
            for name in used_param_names {
                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
                let field_ident = format_ident!("{}", name);
                if for_validator {
                    bindings.push(quote! {
                        let #local_ident = &#source_ident.#field_ident;
                    });
                } else {
                    bindings.push(quote! {
                        let #local_ident = #source_ident.#field_ident;
                    });
                }
                declared_params.insert(name.to_string(), local_ident);
            }
        }
    }

    Ok((declared_params, bindings))
}
1384
1385struct ValidatorRenderContext<'a> {
1386 local_params: &'a HashMap<String, syn::Ident>,
1387 top_level_params: &'a HashMap<String, syn::Ident>,
1388 allow_top_level_fallback: bool,
1389 use_dollar_params: bool,
1390 sql_span: Span,
1391 list_count: usize,
1392}
1393
1394fn render_validator_args(
1399 sql: &str,
1400 param_offset: &mut usize,
1401 arg_index: &mut usize,
1402 context: &ValidatorRenderContext<'_>,
1403) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1404 let (rendered_sql, occurrences) = render_validator_text(
1405 sql,
1406 context.use_dollar_params,
1407 param_offset,
1408 context.list_count,
1409 );
1410 let mut setup = Vec::<TokenStream2>::new();
1411 let mut args = Vec::<TokenStream2>::new();
1412
1413 for (name, is_list) in occurrences {
1414 let local_ident = if context.allow_top_level_fallback {
1415 context
1416 .local_params
1417 .get(&name)
1418 .or_else(|| context.top_level_params.get(&name))
1419 } else {
1420 context.local_params.get(&name)
1421 };
1422
1423 let Some(local_ident) = local_ident else {
1424 return Err(syn::Error::new(
1425 context.sql_span,
1426 format!("sql_forge!: parameter :{} has no mapping", name),
1427 )
1428 .to_compile_error()
1429 .into());
1430 };
1431
1432 if is_list {
1433 for _ in 0..context.list_count {
1434 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1435 *arg_index += 1;
1436 setup.push(quote! {
1437 let #value_ident = sql_forge::sql_forge_validator_value(
1438 (#local_ident)
1439 .as_slice()
1440 .first()
1441 .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1442 );
1443 });
1444 args.push(quote! { #value_ident });
1445 }
1446 } else {
1447 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1448 *arg_index += 1;
1449 setup.push(quote! {
1450 let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1451 });
1452 args.push(quote! { #value_ident });
1453 }
1454 }
1455
1456 Ok((rendered_sql, setup, args))
1457}
1458
1459fn render_runtime_fragment(
1466 fragment: &SectionFragment,
1467 local_params: &HashMap<String, syn::Ident>,
1468) -> Result<TokenStream2, TokenStream> {
1469 let mut steps = Vec::<TokenStream2>::new();
1470
1471 for part in parse_text_parts(&fragment.sql) {
1472 match part {
1473 TextPart::Lit(lit) => {
1474 let lit_str = LitStr::new(&lit, fragment.span);
1475 steps.push(quote! { __builder.push(#lit_str); });
1476 }
1477 TextPart::Param { name, is_list } => {
1478 let Some(local_ident) = local_params.get(&name) else {
1479 return Err(syn::Error::new(
1480 fragment.span,
1481 format!("sql_forge!: parameter :{} has no mapping", name),
1482 )
1483 .to_compile_error()
1484 .into());
1485 };
1486
1487 if is_list {
1488 steps.push(quote! {
1489 let __enhanced_values = #local_ident;
1490 let mut __separated = __builder.separated(", ");
1491 for __value in __enhanced_values {
1492 __separated.push_bind(__value);
1493 }
1494 });
1495 } else {
1496 steps.push(quote! {
1497 __builder.push_bind(#local_ident);
1498 });
1499 }
1500 }
1501 }
1502 }
1503
1504 Ok(quote! { #( #steps )* })
1505}
1506
1507fn build_section_runtime_action(
1508 value: &SectionValue,
1509 section_idx: usize,
1510 prefix: &str,
1511) -> Result<TokenStream2, TokenStream> {
1512 match value {
1513 SectionValue::Single(fragment) => {
1514 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1515 let (local_params, bindings) =
1516 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1517 let body = render_runtime_fragment(fragment, &local_params)?;
1518 Ok(quote! {{ #( #bindings )* #body }})
1519 }
1520 SectionValue::Grouped(fragments) => build_section_runtime_action(
1521 &fragments[section_idx],
1522 0,
1523 &format!("{}_grouped_{}", prefix, section_idx),
1524 ),
1525 SectionValue::Match { expr, arms } => {
1526 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1527 .iter()
1528 .enumerate()
1529 .map(|(arm_idx, arm)| {
1530 let pat = &arm.pat;
1531 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1532 let body = build_section_runtime_action(
1533 &arm.value,
1534 section_idx,
1535 &format!("{}_{}", prefix, arm_idx),
1536 )?;
1537 Ok::<TokenStream2, TokenStream>(quote! { #pat #guard_tokens => #body })
1538 })
1539 .collect();
1540 let arm_tokens = arm_tokens?;
1541 Ok(quote! {
1542 match #expr {
1543 #( #arm_tokens ),*
1544 }
1545 })
1546 }
1547 }
1548}
1549
1550fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1551 let mut names = Vec::new();
1552 let mut seen = HashSet::<String>::new();
1553
1554 for segment in segments {
1555 match segment {
1556 Segment::Text(text) => {
1557 for name in collect_used_param_names_in_sql(text) {
1558 if seen.insert(name.clone()) {
1559 names.push(name);
1560 }
1561 }
1562 }
1563 Segment::Batch { parts } => {
1564 for part in parts {
1565 if let TextPart::Param { name, .. } = part {
1566 if seen.insert(name.clone()) {
1567 names.push(name.clone());
1568 }
1569 }
1570 }
1571 }
1572 _ => {}
1573 }
1574 }
1575
1576 names
1577}
1578
1579fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1580 let mut names = Vec::new();
1581 let mut seen = HashSet::<String>::new();
1582 for part in parse_text_parts(sql) {
1583 if let TextPart::Param { name, .. } = part {
1584 if seen.insert(name.to_string()) {
1585 names.push(name);
1586 }
1587 }
1588 }
1589 names
1590}
1591
1592#[proc_macro]
1818#[allow(clippy::too_many_lines)]
1819pub fn sql_forge(input: TokenStream) -> TokenStream {
1820 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1822 let SqlForgeInput {
1823 db,
1824 result,
1825 force_scalar,
1826 sql,
1827 params,
1828 sections,
1829 batch,
1830 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1831 Ok(v) => v,
1832 Err(err) => return err.to_compile_error().into(),
1833 };
1834
1835 let db = match db {
1837 Some(db) => db,
1838 None => match resolve_db_from_env() {
1839 Ok(db) => db,
1840 Err(msg) => {
1841 return syn::Error::new(Span::call_site(), msg)
1842 .to_compile_error()
1843 .into();
1844 }
1845 },
1846 };
1847
1848 let use_dollar_params = uses_dollar_params(&db);
1849 let is_sqlite = if let syn::Type::Path(type_path) = &db {
1850 type_path
1851 .path
1852 .segments
1853 .last()
1854 .is_some_and(|s| s.ident == "Sqlite")
1855 } else {
1856 false
1857 };
1858 let list_count: usize = if is_sqlite { 1 } else { 3 };
1859
1860 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1864 ResultSpec::None => {
1865 vec![(None, None, None)]
1866 }
1867 ResultSpec::Single(ref model) => {
1868 let model_ty = (**model).clone();
1869 let scalar = if force_scalar {
1870 Some(model_ty.clone())
1871 } else {
1872 scalar_output_type(model.as_ref()).cloned()
1873 };
1874 vec![(None, Some(model_ty), scalar)]
1875 }
1876 ResultSpec::Group(ref cases) => {
1877 if force_scalar {
1878 return syn::Error::new(
1879 Span::call_site(),
1880 "sql_forge!: scalar mode is not supported for grouped result maps",
1881 )
1882 .to_compile_error()
1883 .into();
1884 }
1885
1886 let mut out = Vec::new();
1887 let mut seen = HashSet::new();
1888 for case in cases {
1889 let key = case.name.to_string();
1890 if !seen.insert(key.clone()) {
1891 return syn::Error::new(
1892 case.name.span(),
1893 "sql_forge!: duplicated key in result map",
1894 )
1895 .to_compile_error()
1896 .into();
1897 }
1898
1899 let model = case.model.clone();
1900 let scalar = if case.force_scalar {
1901 Some(model.clone())
1902 } else {
1903 scalar_output_type(&case.model).cloned()
1904 };
1905 out.push((Some(key), Some(model), scalar));
1906 }
1907 out
1908 }
1909 };
1910 let group_result_keys: Vec<String> = result_cases
1911 .iter()
1912 .filter_map(|(key, _, _)| key.as_ref().cloned())
1913 .collect();
1914 let is_grouped_result = !group_result_keys.is_empty();
1915 let sql_span = sql.span();
1916
1917 let segments = match sql.into_segments() {
1919 Ok(segments) => segments,
1920 Err(msg) => {
1921 return syn::Error::new(sql_span, msg).to_compile_error().into();
1922 }
1923 };
1924
1925 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
1926 match (&batch, has_batch_segment) {
1927 (None, true) => {
1928 return syn::Error::new(
1929 sql_span,
1930 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
1931 was provided"
1932 )
1933 .to_compile_error()
1934 .into();
1935 }
1936 (Some(_), false) => {
1937 return syn::Error::new(
1938 sql_span,
1939 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
1940 batch section",
1941 )
1942 .to_compile_error()
1943 .into();
1944 }
1945 _ => {}
1946 }
1947
1948 let used_param_names = collect_used_param_names(&segments);
1949
1950 let batch_param_names: std::collections::HashSet<String> = segments
1955 .iter()
1956 .filter_map(|s| {
1957 if let Segment::Batch { parts } = s {
1958 Some(parts.iter().filter_map(|p| {
1959 if let TextPart::Param { name, .. } = p {
1960 Some(name.clone())
1961 } else {
1962 None
1963 }
1964 }))
1965 } else {
1966 None
1967 }
1968 })
1969 .flatten()
1970 .collect();
1971 let top_level_used_names: Vec<String> = used_param_names
1972 .iter()
1973 .filter(|n| !batch_param_names.contains(*n))
1974 .cloned()
1975 .collect();
1976
1977 let (declared_params, validator_param_bindings) =
1979 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
1980 Ok(v) => v,
1981 Err(err) => return err,
1982 };
1983
1984 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
1985
1986 for assign in §ions {
1988 let SectionAssign { names, value } = assign;
1989
1990 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
1992 for (section_idx, name_ident) in names.iter().enumerate() {
1993 let name = name_ident.to_string();
1994 if runtime_section_actions.contains_key(&name) {
1995 return syn::Error::new(
1996 name_ident.span(),
1997 "sql_forge!: duplicated section mapping",
1998 )
1999 .to_compile_error()
2000 .into();
2001 }
2002 let action = match build_section_runtime_action(
2003 value,
2004 section_idx,
2005 &format!("section_{}", name),
2006 ) {
2007 Ok(action) => action,
2008 Err(err) => return err,
2009 };
2010 named_actions.push((name, action));
2011 }
2012
2013 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2015 return syn::Error::new(names[0].span(), msg)
2016 .to_compile_error()
2017 .into();
2018 }
2019
2020 for (name, action) in named_actions {
2021 runtime_section_actions.insert(name, action);
2022 }
2023 }
2024
2025 let sql_section_names: std::collections::HashSet<&str> = segments
2026 .iter()
2027 .filter_map(|seg| {
2028 if let Segment::Section { name } = seg {
2029 Some(name.as_str())
2030 } else {
2031 None
2032 }
2033 })
2034 .collect();
2035 for name in runtime_section_actions.keys() {
2036 if !sql_section_names.contains(name.as_str()) {
2037 return syn::Error::new(
2038 sql_span,
2039 format!(
2040 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2041 name, name,
2042 ),
2043 )
2044 .to_compile_error()
2045 .into();
2046 }
2047 }
2048
2049 let mut generated_query_defs = Vec::<TokenStream2>::new();
2051 let mut generated_query_values = Vec::<TokenStream2>::new();
2052 let mut group_field_defs = Vec::<TokenStream2>::new();
2053 let mut group_method_defs = Vec::<TokenStream2>::new();
2054 let mut group_field_idents = Vec::<syn::Ident>::new();
2055 let mut group_field_tys = Vec::<TokenStream2>::new();
2056 let mut group_trait_impls = Vec::<TokenStream2>::new();
2057
2058 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2059
2060 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2061 let suffix = result_key.as_deref().unwrap_or("single");
2062 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2063 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2064
2065 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2066
2067 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2068 for assign in §ions {
2069 let SectionAssign { names, value } = assign;
2070 let variants_by_section = match collect_section_variants_for_result(
2071 value.clone(),
2072 names.len(),
2073 result_key.as_deref(),
2074 ) {
2075 Ok(v) => v,
2076 Err(msg) => {
2077 return syn::Error::new(names[0].span(), msg)
2078 .to_compile_error()
2079 .into();
2080 }
2081 };
2082
2083 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2084 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2085 }
2086 }
2087
2088 let mut nmax = 1usize;
2089 for segment in &segments {
2090 if let Segment::Section { name } = segment {
2091 if let Some(variants) = section_variants_for_validation.get(name) {
2092 if variants.is_empty() {
2093 return syn::Error::new(
2094 sql_span,
2095 format!("sql_forge!: section {{#{}}} has no possible variants", name),
2096 )
2097 .to_compile_error()
2098 .into();
2099 }
2100 nmax = nmax.max(variants.len());
2101 } else {
2102 return syn::Error::new(
2103 sql_span,
2104 format!("sql_forge!: section {{#{}}} has no mapping", name),
2105 )
2106 .to_compile_error()
2107 .into();
2108 }
2109 }
2110 }
2111
2112 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2113 for case_idx in 0..nmax {
2114 let mut sql_case = String::new();
2115 let mut case_setup = Vec::<TokenStream2>::new();
2116 let mut case_args = Vec::<TokenStream2>::new();
2117 let mut param_offset = 0usize;
2118 let mut arg_index = 0usize;
2119 let empty_params = HashMap::<String, syn::Ident>::new();
2120 let root_validator_context = ValidatorRenderContext {
2121 local_params: &empty_params,
2122 top_level_params: &declared_params,
2123 allow_top_level_fallback: true,
2124 use_dollar_params,
2125 sql_span,
2126 list_count,
2127 };
2128
2129 for segment in &segments {
2130 match segment {
2131 Segment::Text(text) => {
2132 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2133 text,
2134 &mut param_offset,
2135 &mut arg_index,
2136 &root_validator_context,
2137 ) {
2138 Ok(value) => value,
2139 Err(err) => return err,
2140 };
2141 sql_case.push_str(&chunk_sql);
2142 case_setup.extend(chunk_setup);
2143 case_args.extend(chunk_args);
2144 }
2145 Segment::Section { name } => {
2146 let Some(variants) = section_variants_for_validation.get(name) else {
2147 return syn::Error::new(
2148 sql_span,
2149 format!("sql_forge!: section {{#{}}} has no mapping", name),
2150 )
2151 .to_compile_error()
2152 .into();
2153 };
2154
2155 let fragment = &variants[case_idx % variants.len()];
2156 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2157 let (local_params, bindings) = match build_param_bindings(
2158 &fragment.params,
2159 &used_param_names,
2160 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2161 true,
2162 true,
2163 ) {
2164 Ok(value) => value,
2165 Err(err) => return err,
2166 };
2167 let section_validator_context = ValidatorRenderContext {
2168 local_params: &local_params,
2169 top_level_params: &declared_params,
2170 allow_top_level_fallback: false,
2171 use_dollar_params,
2172 sql_span: fragment.span,
2173 list_count,
2174 };
2175 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2176 &fragment.sql,
2177 &mut param_offset,
2178 &mut arg_index,
2179 §ion_validator_context,
2180 ) {
2181 Ok(value) => value,
2182 Err(err) => return err,
2183 };
2184 sql_case.push_str(&chunk_sql);
2185 case_setup.extend(bindings);
2186 case_setup.extend(chunk_setup);
2187 case_args.extend(chunk_args);
2188 }
2189 Segment::Batch { parts } => {
2190 let mut first = true;
2191 for _ in 0..list_count {
2192 let sep = if first { "" } else { ", " };
2193 first = false;
2194 sql_case.push_str(sep);
2195 for tp in parts {
2196 match tp {
2197 TextPart::Lit(lit) => sql_case.push_str(lit),
2198 TextPart::Param { name, .. } => {
2199 if let Some(batch_expr) = &batch {
2200 let field_ident = format_ident!("{}", name);
2201 if use_dollar_params {
2202 param_offset += 1;
2203 write!(sql_case, "${}", param_offset).unwrap();
2204 } else {
2205 sql_case.push('?');
2206 }
2207 case_args.push(quote! { #batch_expr[0].#field_ident });
2208 } else if use_dollar_params {
2209 param_offset += 1;
2210 write!(sql_case, "${}", param_offset).unwrap();
2211 } else {
2212 sql_case.push('?');
2213 }
2214 }
2215 }
2216 }
2217 }
2218 }
2219 }
2220 }
2221
2222 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2223 }
2224
2225 let mut validator_invocations = Vec::<TokenStream2>::new();
2226 for (sql_lit, case_setup, args) in &validator_cases {
2227 if model_opt.is_none() {
2228 if args.is_empty() {
2229 validator_invocations.push(quote! {
2230 {
2231 #( #case_setup )*
2232 let _ = sqlx::query_scalar!(
2233 #sql_lit,
2234 );
2235 }
2236 });
2237 } else {
2238 validator_invocations.push(quote! {
2239 {
2240 #( #case_setup )*
2241 let _ = sqlx::query_scalar!(
2242 #sql_lit,
2243 #( #args ),*
2244 );
2245 }
2246 });
2247 }
2248 } else if let Some(scalar_ty) = scalar_model_ty {
2249 if args.is_empty() {
2250 validator_invocations.push(quote! {
2251 {
2252 #( #case_setup )*
2253 let _ = sqlx::query_scalar!(
2254 #sql_lit,
2255 );
2256 }
2257 });
2258 } else {
2259 validator_invocations.push(quote! {
2260 {
2261 #( #case_setup )*
2262 let _ = sqlx::query_scalar!(
2263 #sql_lit,
2264 #( #args ),*
2265 );
2266 }
2267 });
2268 }
2269 let _ = scalar_ty;
2270 } else if args.is_empty() {
2271 validator_invocations.push(quote! {
2272 {
2273 #( #case_setup )*
2274 let _ = sqlx::query_as!(
2275 __EnhancedModel,
2276 #sql_lit,
2277 );
2278 }
2279 });
2280 } else {
2281 validator_invocations.push(quote! {
2282 {
2283 #( #case_setup )*
2284 let _ = sqlx::query_as!(
2285 __EnhancedModel,
2286 #sql_lit,
2287 #( #args ),*
2288 );
2289 }
2290 });
2291 }
2292 }
2293
2294 let model_alias = if let Some(model) = model_opt {
2295 if scalar_model_ty.is_none() {
2296 quote! { type __EnhancedModel = #model; }
2297 } else {
2298 quote! {}
2299 }
2300 } else {
2301 quote! {}
2302 };
2303 grouped_validator_invocations.push(quote! {
2304 {
2305 #( #flag_bindings )*
2306 #model_alias
2307 #( #validator_invocations )*
2308 }
2309 });
2310
2311 let (runtime_declared_params, runtime_param_bindings) =
2312 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2313 Ok(v) => v,
2314 Err(err) => return err,
2315 };
2316
2317 let mut runtime_steps = Vec::<TokenStream2>::new();
2318 for (seg_idx, segment) in segments.iter().enumerate() {
2319 match segment {
2320 Segment::Text(text) => {
2321 for part in parse_text_parts(text) {
2322 match part {
2323 TextPart::Lit(lit) => {
2324 let lit = sanitize_runtime_sql_text(&lit);
2325 let lit_str = LitStr::new(&lit, sql_span);
2326 runtime_steps.push(quote! {
2327 __builder.push(#lit_str);
2328 });
2329 }
2330 TextPart::Param { name, is_list } => {
2331 let Some(local_ident) = runtime_declared_params.get(&name) else {
2332 return syn::Error::new(
2333 sql_span,
2334 format!("sql_forge!: parameter :{} has no mapping", name),
2335 )
2336 .to_compile_error()
2337 .into();
2338 };
2339
2340 if is_list {
2341 runtime_steps.push(quote! {
2342 let __enhanced_values = #local_ident;
2343 let mut __separated = __builder.separated(", ");
2344 for __value in __enhanced_values {
2345 __separated.push_bind(__value);
2346 }
2347 });
2348 } else {
2349 runtime_steps.push(quote! {
2350 __builder.push_bind(#local_ident);
2351 });
2352 }
2353 }
2354 }
2355 }
2356 }
2357 Segment::Section { name } => {
2358 let Some(section_action) = runtime_section_actions.get(name) else {
2359 let _ = seg_idx;
2360 return syn::Error::new(
2361 sql_span,
2362 format!("sql_forge!: section {{#{}}} has no mapping", name),
2363 )
2364 .to_compile_error()
2365 .into();
2366 };
2367 runtime_steps.push(quote! {
2368 #section_action
2369 });
2370 }
2371 Segment::Batch { parts } => {
2372 if let Some(batch_expr) = &batch {
2373 let mut body = Vec::<TokenStream2>::new();
2374 for part in parts {
2375 match part {
2376 TextPart::Lit(lit) => {
2377 let lit_str = LitStr::new(lit, sql_span);
2378 body.push(quote! {
2379 __builder.push(#lit_str);
2380 });
2381 }
2382 TextPart::Param { name, .. } => {
2383 let field_ident = format_ident!("{}", name);
2384 body.push(quote! {
2385 __builder.push_bind(__item.#field_ident);
2386 });
2387 }
2388 }
2389 }
2390 runtime_steps.push(quote! {
2391 {
2392 let mut __first = true;
2393 for __item in #batch_expr {
2394 if !__first {
2395 __builder.push(", ");
2396 }
2397 __first = false;
2398 #( #body )*
2399 }
2400 }
2401 });
2402 }
2403 }
2404 }
2405 }
2406
2407 let exec_methods = if model_opt.is_none() {
2408 quote! {
2409 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2410 where
2411 E: sqlx::Executor<'e, Database = #db>,
2412 {
2413 self.inner.build().execute(executor).await
2414 }
2415 }
2416 } else if let Some(scalar_ty) = scalar_model_ty {
2417 quote! {
2418 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2419 where
2420 E: sqlx::Executor<'e, Database = #db>,
2421 {
2422 self.inner
2423 .build_query_scalar::<#scalar_ty>()
2424 .fetch_all(executor)
2425 .await
2426 }
2427
2428 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2429 where
2430 E: sqlx::Executor<'e, Database = #db>,
2431 {
2432 self.inner
2433 .build_query_scalar::<#scalar_ty>()
2434 .fetch_one(executor)
2435 .await
2436 }
2437
2438 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2439 where
2440 E: sqlx::Executor<'e, Database = #db>,
2441 {
2442 self.inner
2443 .build_query_scalar::<#scalar_ty>()
2444 .fetch_optional(executor)
2445 .await
2446 }
2447
2448 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2449 where
2450 E: sqlx::Executor<'e, Database = #db>,
2451 {
2452 self.inner.build().execute(executor).await
2453 }
2454 }
2455 } else {
2456 let model = model_opt.as_ref().unwrap();
2457 quote! {
2458 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2459 where
2460 E: sqlx::Executor<'e, Database = #db>,
2461 {
2462 self.inner.build_query_as::<#model>().fetch_all(executor).await
2463 }
2464
2465 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2466 where
2467 E: sqlx::Executor<'e, Database = #db>,
2468 {
2469 self.inner.build_query_as::<#model>().fetch_one(executor).await
2470 }
2471
2472 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2473 where
2474 E: sqlx::Executor<'e, Database = #db>,
2475 {
2476 self.inner
2477 .build_query_as::<#model>()
2478 .fetch_optional(executor)
2479 .await
2480 }
2481
2482 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2483 where
2484 E: sqlx::Executor<'e, Database = #db>,
2485 {
2486 self.inner.build().execute(executor).await
2487 }
2488 }
2489 };
2490
2491 let final_type: TokenStream2 = if let Some(model) = model_opt {
2492 if let Some(scalar_ty) = scalar_model_ty {
2493 quote! { #scalar_ty }
2494 } else {
2495 quote! { #model }
2496 }
2497 } else {
2498 quote! {}
2499 };
2500 let trait_impl = if model_opt.is_none() {
2501 quote! {
2502 impl<'args> sql_forge::SqlForgeQueryExecute
2503 for #query_ident<'args>
2504 {
2505 type Db = #db;
2506
2507 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2508 where
2509 Self: Sized + 'e,
2510 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2511 #db: 'e,
2512 {
2513 #query_ident::execute(self, executor)
2514 }
2515 }
2516 }
2517 } else {
2518 quote! {
2519 impl<'args> sql_forge::SqlForgeQuery<#final_type>
2520 for #query_ident<'args>
2521 {
2522 type Db = #db;
2523
2524 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2525 where
2526 Self: Sized + 'e,
2527 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2528 #db: 'e,
2529 {
2530 #query_ident::fetch_all(self, executor)
2531 }
2532
2533 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2534 where
2535 Self: Sized + 'e,
2536 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2537 #db: 'e,
2538 {
2539 #query_ident::fetch_one(self, executor)
2540 }
2541
2542 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2543 where
2544 Self: Sized + 'e,
2545 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2546 #db: 'e,
2547 {
2548 #query_ident::fetch_optional(self, executor)
2549 }
2550
2551 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2552 where
2553 Self: Sized + 'e,
2554 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2555 #db: 'e,
2556 {
2557 #query_ident::execute(self, executor)
2558 }
2559 }
2560 }
2561 };
2562
2563 generated_query_defs.push(quote! {
2564 struct #query_ident<'args> {
2565 inner: sqlx::QueryBuilder<'args, #db>,
2566 }
2567
2568 impl<'args> #query_ident<'args> {
2569 #exec_methods
2570 }
2571
2572 #trait_impl
2573 });
2574
2575 generated_query_values.push(quote! {
2576 #( #runtime_param_bindings )*
2577 #( #flag_bindings )*
2578 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2579 #( #runtime_steps )*
2580 let #query_value_ident = #query_ident { inner: __builder };
2581 });
2582
2583 if let Some(key) = result_key {
2584 let method_ident = format_ident!("{}", key);
2585 group_field_defs.push(quote! {
2586 #method_ident: #query_ident<'args>
2587 });
2588 group_field_tys.push(quote! { #query_ident<'args> });
2589 group_method_defs.push(quote! {
2590 pub fn #method_ident(self) -> #query_ident<'args> {
2591 self.#method_ident
2592 }
2593 });
2594
2595 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2596 group_trait_impls.push(quote! {
2597 struct #key_ty_ident;
2598
2599 impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2600 type Query = #query_ident<'args>;
2601
2602 fn get(self, _: #key_ty_ident) -> Self::Query {
2603 self.#method_ident
2604 }
2605 }
2606 });
2607 group_field_idents.push(method_ident);
2608 }
2609 }
2610
2611 let validator_tokens = quote! {
2613 let _sql_forge_validator = || {
2614 #( #validator_param_bindings )*
2615 #( #grouped_validator_invocations )*
2616 };
2617 };
2618
2619 if !is_grouped_result {
2620 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2621 return quote! {
2622 {
2623 #validator_tokens
2624 #( #generated_query_defs )*
2625 #( #generated_query_values )*
2626 #single_query_value_ident
2627 }
2628 }
2629 .into();
2630 }
2631
2632 let group_field_inits: Vec<TokenStream2> = result_cases
2633 .iter()
2634 .filter_map(|(key, _, _)| key.as_ref())
2635 .map(|key| {
2636 let method_ident = format_ident!("{}", key);
2637 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2638 quote! { #method_ident: #query_value_ident }
2639 })
2640 .collect();
2641
2642 quote! {
2643 {
2644 #validator_tokens
2645
2646 #( #generated_query_defs )*
2647 #( #generated_query_values )*
2648
2649 struct __SqlForgeQueryGroup<'args> {
2650 #( #group_field_defs, )*
2651 }
2652
2653 impl<'args> __SqlForgeQueryGroup<'args> {
2654 #( #group_method_defs )*
2655
2656 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2657 ( #( self.#group_field_idents ),* )
2658 }
2659 }
2660
2661 impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2662 type Db = #db;
2663 }
2664
2665 #( #group_trait_impls )*
2666
2667 __SqlForgeQueryGroup {
2668 #( #group_field_inits, )*
2669 }
2670 }
2671 }
2672 .into()
2673}
2674
2675#[proc_macro]
2689pub fn db_type(input: TokenStream) -> TokenStream {
2690 if !input.is_empty() {
2691 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2692 .to_compile_error()
2693 .into();
2694 }
2695
2696 match resolve_db_from_env() {
2697 Ok(db) => quote! { #db }.into(),
2698 Err(msg) => syn::Error::new(Span::call_site(), msg)
2699 .to_compile_error()
2700 .into(),
2701 }
2702}