1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::spanned::Spanned;
14use syn::{
15 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Ident, Lit, LitStr, Pat, Stmt, Token, Type,
16};
17
/// Custom keywords accepted by the macro's input grammar.
mod kw {
    // `scalar` may precede a result model type; the parsers in this file set
    // `force_scalar` when they see it.
    syn::custom_keyword!(scalar);
}
25
/// One `:name = expr` entry of a parameter map.
#[derive(Clone)]
struct ParamAssign {
    /// Parameter name written after the leading `:`.
    name: Ident,
    /// Expression on the right-hand side of `=`.
    expr: Expr,
}
32
33impl Parse for ParamAssign {
34 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
35 input.parse::<Token![:]>()?;
36 let name: Ident = input.parse()?;
37 input.parse::<Token![=]>()?;
38 let expr: Expr = input.parse()?;
39 Ok(Self { name, expr })
40 }
41}
42
/// A single SQL fragment that can be substituted for a section placeholder.
#[derive(Clone)]
struct SectionFragment {
    /// SQL text taken from the fragment's string literal.
    sql: String,
    /// Span of the expression the SQL text came from; used for generated
    /// literals and for error reporting.
    span: Span,
    /// Parameter source local to this fragment (`ParamsSource::None` when the
    /// fragment is a bare string literal).
    params: ParamsSource,
}
50
/// One arm of a `match` used as a section value.
#[derive(Clone)]
struct SectionMatchArm {
    /// Pattern on the left of `=>`.
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Section value produced by this arm.
    value: SectionValue,
}
58
/// Right-hand side of a section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single fragment (used when the assignment names one section).
    Single(SectionFragment),
    /// A parenthesized list of values, one per named section.
    Grouped(Vec<SectionValue>),
    /// A `match <expr> { ... }` whose arms each yield a section value.
    Match {
        /// Scrutinee expression of the match.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
72
/// A `#name = value` or `#(a, b, ...) = value` entry of a section map.
#[derive(Clone)]
struct SectionAssign {
    /// Section names targeted by this assignment (one, or several when grouped).
    names: Vec<Ident>,
    /// Value parsed with a width equal to `names.len()`.
    value: SectionValue,
}
79
80impl Parse for SectionAssign {
81 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
82 input.parse::<Token![#]>()?;
83
84 let names = if input.peek(syn::token::Paren) {
86 let content;
87 syn::parenthesized!(content in input);
88 let mut out = Vec::new();
89 while !content.is_empty() {
90 out.push(content.parse::<Ident>()?);
91 if content.is_empty() {
92 break;
93 }
94 content.parse::<Token![,]>()?;
95 }
96 if out.is_empty() {
97 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
98 }
99 out
100 } else {
101 vec![input.parse::<Ident>()?]
102 };
103
104 input.parse::<Token![=]>()?;
105 let value = parse_section_value(input, names.len())?;
106 Ok(Self { names, value })
107 }
108}
109
/// Fully parsed arguments of one macro invocation.
struct SqlForgeInput {
    /// Database type given as the first macro argument, if any.
    db: Option<Type>,
    /// Result specification (none, a single model, or a named result map).
    result: ResultSpec,
    /// Set when the single result model was prefixed with the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template string literal.
    sql: SqlTemplate,
    /// Top-level parameter source (`ParamsSource::None` when not given).
    params: ParamsSource,
    /// Section assignments from the section map argument (empty when not given).
    sections: Vec<SectionAssign>,
    /// Expression following `..`, if a batch source argument was given.
    batch: Option<Expr>,
}
120
/// One `>name = Model` (or `>name = scalar Model`) entry of a result map.
#[derive(Clone)]
struct ResultAssign {
    /// Result key written after the leading `>`.
    name: Ident,
    /// Model type on the right-hand side.
    model: Type,
    /// True when the model was prefixed with the `scalar` keyword.
    force_scalar: bool,
}
128
/// What kind of result the invocation declares.
#[derive(Clone)]
enum ResultSpec {
    /// No result model was given (the invocation starts with the SQL literal).
    None,
    /// A single model type.
    Single(Box<Type>),
    /// A parenthesized result map of named models.
    Group(Vec<ResultAssign>),
}
138
/// Where the values for `:name` placeholders come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was supplied.
    None,
    /// Explicit `(:name = expr, ...)` mappings.
    Map(Vec<ParamAssign>),
    /// A single expression whose fields are read by parameter name.
    Struct(Box<Expr>),
}
147
/// The SQL template argument of the macro.
enum SqlTemplate {
    /// Template supplied as a string literal (the only form currently parsed).
    Literal(LitStr),
}
152
153impl SqlTemplate {
154 fn span(&self) -> Span {
155 match self {
156 Self::Literal(lit) => lit.span(),
157 }
158 }
159
160 fn into_segments(self) -> Result<Vec<Segment>, String> {
162 match self {
163 Self::Literal(lit) => parse_literal_segments(&lit.value()),
164 }
165 }
166}
167
168fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
169 if input.peek(LitStr) {
170 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
171 } else {
172 Err(input.error("sql_forge!: SQL template must be a string literal"))
173 }
174}
175
/// A piece of the SQL template after splitting on `{...}` markers.
#[derive(Clone)]
enum Segment {
    /// Plain SQL text.
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, already split into literal/parameter parts.
    Batch { parts: Vec<TextPart> },
}
186
/// A piece of SQL text after splitting on `:name` parameter placeholders.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text between placeholders.
    Lit(String),
    /// A `:name` placeholder; `is_list` is true for the `:name[]` form.
    Param { name: String, is_list: bool },
}
195
/// Kind of a parenthesized map argument, decided by its first token.
enum MapKind {
    /// Starts with `>`.
    Results,
    /// Starts with `:`.
    Params,
    /// Starts with `#`.
    Sections,
}
206
207fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
212 let fork = input.fork();
213 let content;
214 syn::parenthesized!(content in fork);
215
216 if content.is_empty() {
217 return Err(input.error("sql_forge!: map argument cannot be empty"));
218 }
219
220 if content.peek(Token![>]) {
221 Ok(Some(MapKind::Results))
222 } else if content.peek(Token![:]) {
223 Ok(Some(MapKind::Params))
224 } else if content.peek(Token![#]) {
225 Ok(Some(MapKind::Sections))
226 } else {
227 Ok(None)
228 }
229}
230
231impl Parse for ResultAssign {
232 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
233 input.parse::<Token![>]>()?;
234 let name: Ident = input.parse()?;
235 input.parse::<Token![=]>()?;
236 let (force_scalar, model) = if input.peek(kw::scalar) {
237 input.parse::<kw::scalar>()?;
238 (true, input.parse::<Type>()?)
239 } else {
240 (false, input.parse::<Type>()?)
241 };
242 Ok(Self {
243 name,
244 model,
245 force_scalar,
246 })
247 }
248}
249
250fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
251 let content;
252 syn::parenthesized!(content in input);
253
254 let mut results = Vec::new();
255 while !content.is_empty() {
256 results.push(content.parse::<ResultAssign>()?);
257 if content.is_empty() {
258 break;
259 }
260 content.parse::<Token![,]>()?;
261 }
262
263 if results.is_empty() {
264 return Err(input.error("sql_forge!: result map cannot be empty"));
265 }
266
267 Ok(results)
268}
269
270fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
271 let content;
272 syn::parenthesized!(content in input);
273
274 let mut params = Vec::new();
275 while !content.is_empty() {
276 params.push(content.parse::<ParamAssign>()?);
277 if content.is_empty() {
278 break;
279 }
280 content.parse::<Token![,]>()?;
281 }
282
283 Ok(params)
284}
285
286fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
287 let content;
288 syn::parenthesized!(content in input);
289
290 let mut sections = Vec::new();
291 while !content.is_empty() {
292 sections.push(content.parse::<SectionAssign>()?);
293 if content.is_empty() {
294 break;
295 }
296 content.parse::<Token![,]>()?;
297 }
298
299 Ok(sections)
300}
301
302fn parse_params_source_expr(
303 input: ParseStream<'_>,
304 allow_sections: bool,
305) -> syn::Result<ParamsSource> {
306 if input.peek(syn::token::Paren) {
307 match detect_parenthesized_map_kind(input)? {
308 Some(MapKind::Results) => Err(input
309 .error("sql_forge!: result maps are only allowed as the macro result argument")),
310 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
311 Some(MapKind::Sections) if allow_sections => {
312 Err(input.error("sql_forge!: section maps are not allowed here"))
313 }
314 Some(MapKind::Sections) => Err(input.error(
315 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
316 )),
317 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
318 }
319 } else {
320 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
321 }
322}
323
/// Parses one section fragment.
///
/// Two forms are accepted:
/// * `("sql", params)` — a string literal followed by a section-local
///   parameter source (optionally with a trailing comma);
/// * any expression that reduces to a string literal via `extract_lit_str`,
///   in which case the fragment has no local parameters.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Probe the tuple form on a fork first so that a non-matching group can
        // still be parsed below as an ordinary (parenthesized) expression.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Errors from the parameter source are propagated even though
                // this is still the probing pass.
                let _ = parse_params_source_expr(&content, false)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // The probe matched the tuple form completely: now consume
                    // the same tokens from the real input.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content, false)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Fallback: a bare expression that must reduce to a string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
371
372fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
373 if input.peek(Token![match]) {
374 input.parse::<Token![match]>()?;
375 let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
376 let content;
377 syn::braced!(content in input);
378 let mut arms = Vec::new();
379 while !content.is_empty() {
380 let pat = content.call(Pat::parse_multi_with_leading_vert)?;
381 let guard = if content.peek(Token![if]) {
382 content.parse::<Token![if]>()?;
383 Some(content.parse::<Expr>()?)
384 } else {
385 None
386 };
387 content.parse::<Token![=>]>()?;
388 let value = parse_section_value(&content, width)?;
389 if content.peek(Token![,]) {
390 content.parse::<Token![,]>()?;
391 }
392 arms.push(SectionMatchArm { pat, guard, value });
393 }
394 return Ok(SectionValue::Match { expr, arms });
395 }
396
397 if width == 1 {
398 return Ok(SectionValue::Single(parse_section_fragment(input)?));
399 }
400
401 let content;
402 syn::parenthesized!(content in input);
403 let mut items = Vec::new();
404 while !content.is_empty() {
405 items.push(parse_section_value(&content, 1)?);
406 if content.is_empty() {
407 break;
408 }
409 content.parse::<Token![,]>()?;
410 }
411
412 if items.len() != width {
413 return Err(input.error(format!(
414 "sql_forge!: grouped section value must provide exactly {} items",
415 width,
416 )));
417 }
418
419 Ok(SectionValue::Grouped(items))
420}
421
impl Parse for SqlForgeInput {
    /// Parses the full macro argument list.
    ///
    /// Leading part (everything up to and including the SQL literal), as
    /// accepted by the branches below:
    /// * `"sql"`
    /// * `scalar Model, "sql"`
    /// * `(>name = Model, ...), "sql"`
    /// * `Model, "sql"`
    /// * `Db, scalar Model, "sql"`
    /// * `Db, (>name = Model, ...), "sql"`
    /// * `Db, Model, "sql"`
    ///
    /// Trailing part (after an optional comma): any mix of a `..batch`
    /// expression, one parameter source (a `(:name = expr, ...)` map or an
    /// expression) and one section map, separated by commas.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            // Only the SQL template: no DB type and no result model.
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // A leading parenthesized group is only accepted as a result map.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // A leading type: whether it is the model or the DB type depends on
            // what follows the comma.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) {
                // `Model, "sql"`: the leading type is the result model.
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `Db, Model, "sql"`: the leading type is the DB type.
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        // Track what has been seen so duplicates can be rejected.
        let mut seen_params = false;
        let mut seen_sections = false;

        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    // `..expr` is the batch source.
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // Parenthesized but not a map: an expression used as
                            // the parameter source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Continue only when another comma follows; a trailing comma is
                // accepted because the loop condition then sees empty input.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
571
572fn resolve_db_from_env() -> Result<Type, String> {
577 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
578 return syn::parse_str::<Type>(&val).map_err(|err| {
579 format!(
580 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
581 val, err
582 )
583 });
584 }
585
586 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
587 Ok(d) => d,
588 Err(_) => {
589 return Err(
590 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
591 or configure [package.metadata.sql_forge] in Cargo.toml"
592 .to_string(),
593 );
594 }
595 };
596 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
597
598 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
599 format!(
600 "sql_forge!: failed to read {}: {}",
601 manifest_path.display(),
602 err
603 )
604 })?;
605
606 let value: toml::Value = toml::from_str(&cargo_toml)
607 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
608
609 let db_str = value
610 .get("package")
611 .and_then(|v| v.get("metadata"))
612 .and_then(|v| v.get("sql_forge"))
613 .and_then(|v| v.get("db"))
614 .and_then(|v| v.as_str())
615 .ok_or({
616 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
617 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
618 })?;
619
620 syn::parse_str::<Type>(db_str).map_err(|err| {
621 format!(
622 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
623 db_str, err
624 )
625 })
626}
627
628fn uses_dollar_params(db: &Type) -> bool {
629 let Type::Path(type_path) = db else {
630 return false;
631 };
632 type_path
633 .path
634 .segments
635 .last()
636 .is_some_and(|s| s.ident == "Postgres")
637}
638
639fn is_builtin_scalar_type(ty: &Type) -> bool {
640 let Type::Path(type_path) = ty else {
641 return false;
642 };
643
644 if type_path.qself.is_some()
645 || type_path.path.leading_colon.is_some()
646 || type_path.path.segments.len() != 1
647 {
648 return false;
649 }
650
651 let ident = &type_path.path.segments[0].ident;
652 ident == "i8"
653 || ident == "i16"
654 || ident == "i32"
655 || ident == "i64"
656 || ident == "isize"
657 || ident == "u8"
658 || ident == "u16"
659 || ident == "u32"
660 || ident == "u64"
661 || ident == "usize"
662 || ident == "f32"
663 || ident == "f64"
664 || ident == "bool"
665 || ident == "String"
666}
667
668fn scalar_output_type(model: &Type) -> Option<&Type> {
669 if is_builtin_scalar_type(model) {
670 return Some(model);
671 }
672 None
673}
674
675fn push_text_segment(out: &mut Vec<Segment>, text: String) {
676 if text.is_empty() {
677 return;
678 }
679 match out.last_mut() {
680 Some(Segment::Text(existing)) => existing.push_str(&text),
681 _ => out.push(Segment::Text(text)),
682 }
683}
684
685fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
686 let mut out = Vec::new();
687 let mut text = String::new();
688 let mut chars = sql.chars().peekable();
689
690 while let Some(ch) = chars.next() {
691 if ch != '{' {
692 text.push(ch);
693 continue;
694 }
695
696 if chars.peek() == Some(&'(') {
697 push_text_segment(&mut out, std::mem::take(&mut text));
698
699 let mut paren_depth = 0u32;
700 let mut content = String::new();
701 let mut found_close = false;
702 for ch in chars.by_ref() {
703 if ch == '{' {
704 return Err(
705 "sql_forge!: nested braces not allowed inside batch section".to_string()
706 );
707 }
708 if ch == '}' {
709 if paren_depth != 0 {
710 return Err(
711 "sql_forge!: batch section {( ... )} has unbalanced parentheses"
712 .to_string(),
713 );
714 }
715 found_close = true;
716 break;
717 }
718 if ch == '(' {
719 paren_depth += 1;
720 } else if ch == ')' {
721 if paren_depth == 0 {
722 return Err(
723 "sql_forge!: batch section {( ... )} has unbalanced parentheses"
724 .to_string(),
725 );
726 }
727 paren_depth -= 1;
728 }
729 content.push(ch);
730 }
731 if !found_close {
732 return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
733 }
734 let parts = parse_text_parts(&content);
735 for part in &parts {
736 if let TextPart::Param { is_list: true, .. } = part {
737 return Err(
738 "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
739 batch sections; use plain parameters (:name) instead"
740 .to_string(),
741 );
742 }
743 }
744 out.push(Segment::Batch { parts });
745 continue;
746 }
747
748 if chars.peek() != Some(&'#') {
749 text.push(ch);
750 continue;
751 }
752
753 chars.next();
754 push_text_segment(&mut out, std::mem::take(&mut text));
755
756 let mut name = String::new();
757 loop {
758 let Some(next) = chars.next() else {
759 return Err("sql_forge!: section placeholder without closing }".to_string());
760 };
761 if next == '}' {
762 break;
763 }
764 name.push(next);
765 }
766
767 if name.is_empty() {
768 return Err("sql_forge!: empty section placeholder name".to_string());
769 }
770
771 out.push(Segment::Section { name });
772 }
773
774 push_text_segment(&mut out, text);
775 Ok(out)
776}
777
/// True for characters that may start a parameter name: `_` or an ASCII letter.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, '_' | 'a'..='z' | 'A'..='Z')
}
785
786fn is_ident_continue(ch: char) -> bool {
787 is_ident_start(ch) || ch.is_ascii_digit()
788}
789
/// Strips a trailing annotation from the contents of a backtick-quoted name.
///
/// Everything from the first `!`, `?` or `:` onwards is dropped and trailing
/// whitespace of the remainder is trimmed. When no such character is present,
/// or when nothing would be left, the input is returned unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let kept = match content.find(|c: char| matches!(c, '!' | '?' | ':')) {
        Some(cut) => {
            let head = content[..cut].trim_end();
            if head.is_empty() {
                content
            } else {
                head
            }
        }
        None => content,
    };
    kept.to_string()
}
810
811fn sanitize_runtime_sql_text(text: &str) -> String {
812 let mut out = String::with_capacity(text.len());
813 let mut chars = text.chars().peekable();
814
815 while let Some(ch) = chars.next() {
816 if ch != '`' {
817 out.push(ch);
818 continue;
819 }
820
821 let mut content = String::new();
822 let mut closed = false;
823
824 for next in chars.by_ref() {
825 if next == '`' {
826 closed = true;
827 break;
828 }
829 content.push(next);
830 }
831
832 if closed {
833 out.push('`');
834 out.push_str(&sanitize_backticked_alias_ident(&content));
835 out.push('`');
836 } else {
837 out.push('`');
838 out.push_str(&content);
839 break;
840 }
841 }
842
843 out
844}
845
846fn parse_text_parts(text: &str) -> Vec<TextPart> {
847 let mut parts = Vec::new();
848 let mut last = 0usize;
849 let mut iter = text.char_indices().peekable();
850
851 while let Some((idx, ch)) = iter.next() {
852 if ch != ':' {
853 continue;
854 }
855
856 let Some(&(next_idx, next_ch)) = iter.peek() else {
857 continue;
858 };
859
860 if !is_ident_start(next_ch) {
861 continue;
862 }
863
864 if text[..idx].ends_with(':') {
865 continue;
866 }
867
868 if last < idx {
869 parts.push(TextPart::Lit(text[last..idx].to_string()));
870 }
871
872 iter.next();
873
874 let mut name = String::new();
875 name.push(next_ch);
876 let mut end = next_idx + next_ch.len_utf8();
877
878 while let Some(&(j, c)) = iter.peek() {
879 if is_ident_continue(c) {
880 name.push(c);
881 end = j + c.len_utf8();
882 iter.next();
883 } else {
884 break;
885 }
886 }
887
888 let mut is_list = false;
889 if text[end..].starts_with("[]") {
890 is_list = true;
891 end += 2;
892 }
893
894 parts.push(TextPart::Param { name, is_list });
895 last = end;
896 }
897
898 if last < text.len() {
899 parts.push(TextPart::Lit(text[last..].to_string()));
900 }
901
902 parts
903}
904
905fn render_validator_text(
906 text: &str,
907 use_dollar_params: bool,
908 param_offset: &mut usize,
909 list_count: usize,
910) -> (String, Vec<(String, bool)>) {
911 let mut out_sql = String::new();
912 let mut occurrences = Vec::new();
913
914 for part in parse_text_parts(text) {
915 match part {
916 TextPart::Lit(lit) => out_sql.push_str(&lit),
917 TextPart::Param { name, is_list } => {
918 if is_list && list_count > 1 {
919 let slots: Vec<String> = if use_dollar_params {
920 (0..list_count)
921 .map(|i| format!("${}", *param_offset + i + 1))
922 .collect()
923 } else {
924 (0..list_count).map(|_| "?".to_string()).collect()
925 };
926 if use_dollar_params {
927 *param_offset += list_count;
928 }
929 out_sql.push_str(&slots.join(", "));
930 } else if use_dollar_params {
931 *param_offset += 1;
932 write!(out_sql, "${}", *param_offset).unwrap();
933 } else {
934 out_sql.push('?');
935 }
936 occurrences.push((name, is_list));
937 }
938 }
939 }
940
941 (out_sql, occurrences)
942}
943
944fn strip_expr(expr: &Expr) -> &Expr {
945 match expr {
946 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
947 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
948 Expr::Block(ExprBlock { block, .. }) => {
949 if block.stmts.len() != 1 {
950 return expr;
951 }
952 match &block.stmts[0] {
953 Stmt::Expr(inner, None) => strip_expr(inner),
954 _ => expr,
955 }
956 }
957 _ => expr,
958 }
959}
960
961fn extract_lit_str(expr: &Expr) -> Option<String> {
962 match strip_expr(expr) {
963 Expr::Lit(ExprLit {
964 lit: Lit::Str(lit), ..
965 }) => Some(lit.value()),
966 _ => None,
967 }
968}
969
970fn result_flag_ident(name: &str) -> syn::Ident {
975 format_ident!("__enhanced_result_flag_{}", name)
976}
977
978fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
982 fn walk(stream: TokenStream2) -> TokenStream2 {
983 let mut out = TokenStream2::new();
984 let iter = stream.into_iter().peekable();
985
986 for token in iter {
987 match token {
988 TokenTree::Group(group) => {
989 if group.delimiter() == Delimiter::Brace {
990 let mut inner = group.stream().into_iter();
991 let first = inner.next();
992 let second = inner.next();
993 let third = inner.next();
994
995 if let (
996 Some(TokenTree::Punct(p)),
997 Some(TokenTree::Ident(name_ident)),
998 None,
999 ) = (first, second, third)
1000 {
1001 if p.as_char() == '>' {
1002 let ident = result_flag_ident(&name_ident.to_string());
1003 out.extend(std::iter::once(TokenTree::Ident(ident)));
1004 continue;
1005 }
1006 }
1007 }
1008
1009 let new_inner = walk(group.stream());
1010 let mut new_group = Group::new(group.delimiter(), new_inner);
1011 new_group.set_span(group.span());
1012 out.extend(std::iter::once(TokenTree::Group(new_group)));
1013 }
1014 other => out.extend(std::iter::once(other)),
1015 }
1016 }
1017
1018 out
1019 }
1020
1021 walk(input)
1022}
1023
1024fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1025 keys.iter()
1026 .map(|key| {
1027 let ident = result_flag_ident(key);
1028 let enabled = Some(key.as_str()) == active_key;
1029 quote! { let #ident: bool = #enabled; }
1030 })
1031 .collect()
1032}
1033
1034fn transpose_section_case_matrix(
1035 case_matrix: Vec<Vec<SectionFragment>>,
1036 width: usize,
1037) -> Result<Vec<Vec<SectionFragment>>, String> {
1038 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1039
1040 for row in case_matrix {
1041 if row.len() != width {
1042 return Err(
1043 "sql_forge!: grouped sections must return one item per section".to_string(),
1044 );
1045 }
1046 for (section_idx, fragment) in row.into_iter().enumerate() {
1047 per_section[section_idx].push(fragment);
1048 }
1049 }
1050
1051 Ok(per_section)
1052}
1053
/// Flattens a section value into a matrix of cases: each row is one possible
/// case and holds `width` fragments (one per section).
///
/// `active_key` is the result key currently being expanded, if any; it is used
/// to drop `match` arms that cannot be taken for that key.
///
/// # Errors
/// Fails when a value does not provide exactly `width` items, or when a
/// `match` ends up contributing no rows.
fn collect_section_case_matrix(
    value: SectionValue,
    width: usize,
    active_key: Option<&str>,
) -> Result<Vec<Vec<SectionFragment>>, String> {
    match value {
        SectionValue::Single(fragment) => {
            if width != 1 {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }
            Ok(vec![vec![fragment]])
        }
        SectionValue::Grouped(values) => {
            if values.len() != width {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }

            // Collect the variants of every item separately (each item has
            // width 1), remembering the largest variant count.
            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
            let mut nmax = 1usize;

            for value in values {
                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
                for mut row in item_matrix {
                    // Each row of a width-1 matrix must hold exactly one fragment.
                    let fragment = row.pop().ok_or_else(|| {
                        "sql_forge!: grouped sections must return one item per section".to_string()
                    })?;
                    if !row.is_empty() {
                        return Err(
                            "sql_forge!: grouped sections must return one item per section"
                                .to_string(),
                        );
                    }
                    item_variants.push(fragment);
                }
                if item_variants.is_empty() {
                    return Err("sql_forge!: section match must have at least one arm".to_string());
                }
                nmax = nmax.max(item_variants.len());
                variants_by_section.push(item_variants);
            }

            // Build `nmax` rows; items with fewer variants are cycled through
            // (index modulo their own length).
            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
            for case_idx in 0..nmax {
                let mut row = Vec::<SectionFragment>::with_capacity(width);
                for variants in &variants_by_section {
                    row.push(variants[case_idx % variants.len()].clone());
                }
                case_matrix.push(row);
            }

            Ok(case_matrix)
        }
        SectionValue::Match { expr, arms } => {
            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();

            if let Some(key) = expr_result_flag_key(&expr) {
                // The scrutinee is a result flag, whose value is known here:
                // true exactly when its key is the active one.
                let target = active_key == Some(key.as_str());
                for arm in arms {
                    // Only unguarded arms whose pattern definitely does not
                    // match are skipped; everything else is kept.
                    if arm.guard.is_none() {
                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
                            continue;
                        }
                    }
                    case_matrix.extend(collect_section_case_matrix(arm.value, width, active_key)?);
                }
            } else {
                // Unknown scrutinee: every arm contributes its cases.
                for arm in arms {
                    case_matrix.extend(collect_section_case_matrix(arm.value, width, active_key)?);
                }
            }

            if case_matrix.is_empty() {
                return Err("sql_forge!: section match must have at least one arm".to_string());
            }

            Ok(case_matrix)
        }
    }
}
1138
1139fn collect_section_variants(
1149 value: SectionValue,
1150 width: usize,
1151) -> Result<Vec<Vec<SectionFragment>>, String> {
1152 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1153}
1154
1155fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1156 match strip_expr(expr) {
1157 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1158 let name = path.path.segments[0].ident.to_string();
1159 name.strip_prefix("__enhanced_result_flag_")
1160 .map(|v| v.to_string())
1161 }
1162 _ => None,
1163 }
1164}
1165
1166fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1167 match pat {
1168 Pat::Lit(expr_lit) => match &expr_lit.lit {
1169 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1170 _ => None,
1171 },
1172 Pat::Wild(_) => Some(true),
1173 _ => None,
1174 }
1175}
1176
1177fn collect_section_variants_for_result(
1182 value: SectionValue,
1183 width: usize,
1184 active_key: Option<&str>,
1185) -> Result<Vec<Vec<SectionFragment>>, String> {
1186 transpose_section_case_matrix(
1187 collect_section_case_matrix(value, width, active_key)?,
1188 width,
1189 )
1190}
1191
1192fn build_param_bindings(
1200 params: &ParamsSource,
1201 used_param_names: &[String],
1202 prefix: &str,
1203 for_validator: bool,
1204 enforce_usage_check: bool,
1205) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1206 let mut declared_params = HashMap::<String, syn::Ident>::new();
1207 let mut bindings = Vec::<TokenStream2>::new();
1208
1209 match params {
1210 ParamsSource::None => {}
1211 ParamsSource::Map(entries) => {
1212 for entry in entries {
1213 let key = entry.name.to_string();
1214 if declared_params.contains_key(&key) {
1215 return Err(syn::Error::new(
1216 entry.name.span(),
1217 "sql_forge!: duplicated parameter mapping",
1218 )
1219 .to_compile_error()
1220 .into());
1221 }
1222 if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1223 return Err(syn::Error::new(
1224 entry.name.span(),
1225 format!(
1226 "sql_forge!: parameter :{} is unused in the SQL template",
1227 key,
1228 ),
1229 )
1230 .to_compile_error()
1231 .into());
1232 }
1233 let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1234 let expr = &entry.expr;
1235 if for_validator {
1236 bindings.push(quote! {
1237 let #local_ident = &(#expr);
1238 });
1239 } else {
1240 bindings.push(quote! {
1241 let #local_ident = #expr;
1242 });
1243 }
1244 declared_params.insert(key, local_ident);
1245 }
1246 }
1247 ParamsSource::Struct(expr) => {
1248 let source_ident = format_ident!("__enhanced_source_{}", prefix);
1249 bindings.push(quote! {
1250 let #source_ident = &(#expr);
1251 });
1252 for name in used_param_names {
1253 let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1254 let field_ident = format_ident!("{}", name);
1255 bindings.push(quote! {
1256 let #local_ident = #source_ident.#field_ident;
1257 });
1258 declared_params.insert(name.to_string(), local_ident);
1259 }
1260 }
1261 }
1262
1263 Ok((declared_params, bindings))
1264}
1265
/// Settings shared by validator-SQL rendering calls.
struct ValidatorRenderContext<'a> {
    /// Parameter bindings looked up first when resolving a `:name`.
    local_params: &'a HashMap<String, syn::Ident>,
    /// Bindings consulted second, and only when `allow_top_level_fallback` is set.
    top_level_params: &'a HashMap<String, syn::Ident>,
    /// Whether a name missing from `local_params` may fall back to `top_level_params`.
    allow_top_level_fallback: bool,
    /// Emit `$N` placeholders instead of `?`.
    use_dollar_params: bool,
    /// Span used for "parameter has no mapping" errors.
    sql_span: Span,
    /// Number of placeholders/arguments emitted for a list parameter.
    list_count: usize,
}
1274
1275fn render_validator_args(
1280 sql: &str,
1281 param_offset: &mut usize,
1282 context: &ValidatorRenderContext<'_>,
1283) -> Result<(String, Vec<TokenStream2>), TokenStream> {
1284 let (rendered_sql, occurrences) = render_validator_text(
1285 sql,
1286 context.use_dollar_params,
1287 param_offset,
1288 context.list_count,
1289 );
1290 let mut args = Vec::<TokenStream2>::new();
1291
1292 for (name, is_list) in occurrences {
1293 let local_ident = if context.allow_top_level_fallback {
1294 context
1295 .local_params
1296 .get(&name)
1297 .or_else(|| context.top_level_params.get(&name))
1298 } else {
1299 context.local_params.get(&name)
1300 };
1301
1302 let Some(local_ident) = local_ident else {
1303 return Err(syn::Error::new(
1304 context.sql_span,
1305 format!("sql_forge!: parameter :{} has no mapping", name),
1306 )
1307 .to_compile_error()
1308 .into());
1309 };
1310
1311 if is_list {
1312 for _ in 0..context.list_count {
1313 args.push(quote! { *(#local_ident).as_slice().first().unwrap_or(&0i64) });
1314 }
1315 } else {
1316 args.push(quote! { #local_ident });
1317 }
1318 }
1319
1320 Ok((rendered_sql, args))
1321}
1322
1323fn render_runtime_fragment(
1330 fragment: &SectionFragment,
1331 local_params: &HashMap<String, syn::Ident>,
1332) -> Result<TokenStream2, TokenStream> {
1333 let mut steps = Vec::<TokenStream2>::new();
1334
1335 for part in parse_text_parts(&fragment.sql) {
1336 match part {
1337 TextPart::Lit(lit) => {
1338 let lit_str = LitStr::new(&lit, fragment.span);
1339 steps.push(quote! { __builder.push(#lit_str); });
1340 }
1341 TextPart::Param { name, is_list } => {
1342 let Some(local_ident) = local_params.get(&name) else {
1343 return Err(syn::Error::new(
1344 fragment.span,
1345 format!("sql_forge!: parameter :{} has no mapping", name),
1346 )
1347 .to_compile_error()
1348 .into());
1349 };
1350
1351 if is_list {
1352 steps.push(quote! {
1353 let __enhanced_values = #local_ident;
1354 let mut __separated = __builder.separated(", ");
1355 for __value in __enhanced_values {
1356 __separated.push_bind(__value);
1357 }
1358 });
1359 } else {
1360 steps.push(quote! {
1361 __builder.push_bind(#local_ident);
1362 });
1363 }
1364 }
1365 }
1366 }
1367
1368 Ok(quote! { #( #steps )* })
1369}
1370
1371fn build_section_runtime_action(
1372 value: &SectionValue,
1373 section_idx: usize,
1374 prefix: &str,
1375) -> Result<TokenStream2, TokenStream> {
1376 match value {
1377 SectionValue::Single(fragment) => {
1378 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1379 let (local_params, bindings) =
1380 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1381 let body = render_runtime_fragment(fragment, &local_params)?;
1382 Ok(quote! {{ #( #bindings )* #body }})
1383 }
1384 SectionValue::Grouped(fragments) => build_section_runtime_action(
1385 &fragments[section_idx],
1386 0,
1387 &format!("{}_grouped_{}", prefix, section_idx),
1388 ),
1389 SectionValue::Match { expr, arms } => {
1390 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1391 .iter()
1392 .enumerate()
1393 .map(|(arm_idx, arm)| {
1394 let pat = &arm.pat;
1395 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1396 let body = build_section_runtime_action(
1397 &arm.value,
1398 section_idx,
1399 &format!("{}_{}", prefix, arm_idx),
1400 )?;
1401 Ok::<TokenStream2, TokenStream>(quote! { #pat #guard_tokens => #body })
1402 })
1403 .collect();
1404 let arm_tokens = arm_tokens?;
1405 Ok(quote! {
1406 match #expr {
1407 #( #arm_tokens ),*
1408 }
1409 })
1410 }
1411 }
1412}
1413
1414fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1415 let mut names = Vec::new();
1416 let mut seen = HashSet::<String>::new();
1417
1418 for segment in segments {
1419 match segment {
1420 Segment::Text(text) => {
1421 for name in collect_used_param_names_in_sql(text) {
1422 if seen.insert(name.clone()) {
1423 names.push(name);
1424 }
1425 }
1426 }
1427 Segment::Batch { parts } => {
1428 for part in parts {
1429 if let TextPart::Param { name, .. } = part {
1430 if seen.insert(name.clone()) {
1431 names.push(name.clone());
1432 }
1433 }
1434 }
1435 }
1436 _ => {}
1437 }
1438 }
1439
1440 names
1441}
1442
1443fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1444 let mut names = Vec::new();
1445 let mut seen = HashSet::<String>::new();
1446 for part in parse_text_parts(sql) {
1447 if let TextPart::Param { name, .. } = part {
1448 if seen.insert(name.to_string()) {
1449 names.push(name);
1450 }
1451 }
1452 }
1453 names
1454}
1455
1456#[proc_macro]
1707#[allow(clippy::too_many_lines)]
1708pub fn sql_forge(input: TokenStream) -> TokenStream {
1709 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1711 let SqlForgeInput {
1712 db,
1713 result,
1714 force_scalar,
1715 sql,
1716 params,
1717 sections,
1718 batch,
1719 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1720 Ok(v) => v,
1721 Err(err) => return err.to_compile_error().into(),
1722 };
1723
1724 let db = match db {
1726 Some(db) => db,
1727 None => match resolve_db_from_env() {
1728 Ok(db) => db,
1729 Err(msg) => {
1730 return syn::Error::new(Span::call_site(), msg)
1731 .to_compile_error()
1732 .into();
1733 }
1734 },
1735 };
1736
1737 let use_dollar_params = uses_dollar_params(&db);
1738 let is_sqlite = if let syn::Type::Path(type_path) = &db {
1739 type_path
1740 .path
1741 .segments
1742 .last()
1743 .is_some_and(|s| s.ident == "Sqlite")
1744 } else {
1745 false
1746 };
1747 let list_count: usize = if is_sqlite { 1 } else { 3 };
1748
1749 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1753 ResultSpec::None => {
1754 vec![(None, None, None)]
1755 }
1756 ResultSpec::Single(ref model) => {
1757 let model_ty = (**model).clone();
1758 let scalar = if force_scalar {
1759 Some(model_ty.clone())
1760 } else {
1761 scalar_output_type(model.as_ref()).cloned()
1762 };
1763 vec![(None, Some(model_ty), scalar)]
1764 }
1765 ResultSpec::Group(ref cases) => {
1766 if force_scalar {
1767 return syn::Error::new(
1768 Span::call_site(),
1769 "sql_forge!: scalar mode is not supported for grouped result maps",
1770 )
1771 .to_compile_error()
1772 .into();
1773 }
1774
1775 let mut out = Vec::new();
1776 let mut seen = HashSet::new();
1777 for case in cases {
1778 let key = case.name.to_string();
1779 if !seen.insert(key.clone()) {
1780 return syn::Error::new(
1781 case.name.span(),
1782 "sql_forge!: duplicated key in result map",
1783 )
1784 .to_compile_error()
1785 .into();
1786 }
1787
1788 let model = case.model.clone();
1789 let scalar = if case.force_scalar {
1790 Some(model.clone())
1791 } else {
1792 scalar_output_type(&case.model).cloned()
1793 };
1794 out.push((Some(key), Some(model), scalar));
1795 }
1796 out
1797 }
1798 };
1799 let group_result_keys: Vec<String> = result_cases
1800 .iter()
1801 .filter_map(|(key, _, _)| key.as_ref().cloned())
1802 .collect();
1803 let is_grouped_result = !group_result_keys.is_empty();
1804 let sql_span = sql.span();
1805
1806 let segments = match sql.into_segments() {
1808 Ok(segments) => segments,
1809 Err(msg) => {
1810 return syn::Error::new(sql_span, msg).to_compile_error().into();
1811 }
1812 };
1813
1814 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
1815 match (&batch, has_batch_segment) {
1816 (None, true) => {
1817 return syn::Error::new(
1818 sql_span,
1819 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
1820 was provided"
1821 )
1822 .to_compile_error()
1823 .into();
1824 }
1825 (Some(_), false) => {
1826 return syn::Error::new(
1827 sql_span,
1828 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
1829 batch section",
1830 )
1831 .to_compile_error()
1832 .into();
1833 }
1834 _ => {}
1835 }
1836
1837 let used_param_names = collect_used_param_names(&segments);
1838
1839 let batch_param_names: std::collections::HashSet<String> = segments
1844 .iter()
1845 .filter_map(|s| {
1846 if let Segment::Batch { parts } = s {
1847 Some(parts.iter().filter_map(|p| {
1848 if let TextPart::Param { name, .. } = p {
1849 Some(name.clone())
1850 } else {
1851 None
1852 }
1853 }))
1854 } else {
1855 None
1856 }
1857 })
1858 .flatten()
1859 .collect();
1860 let top_level_used_names: Vec<String> = used_param_names
1861 .iter()
1862 .filter(|n| !batch_param_names.contains(*n))
1863 .cloned()
1864 .collect();
1865
1866 let (declared_params, validator_param_bindings) =
1868 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
1869 Ok(v) => v,
1870 Err(err) => return err,
1871 };
1872
1873 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
1874
1875 for assign in §ions {
1877 let SectionAssign { names, value } = assign;
1878
1879 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
1881 for (section_idx, name_ident) in names.iter().enumerate() {
1882 let name = name_ident.to_string();
1883 if runtime_section_actions.contains_key(&name) {
1884 return syn::Error::new(
1885 name_ident.span(),
1886 "sql_forge!: duplicated section mapping",
1887 )
1888 .to_compile_error()
1889 .into();
1890 }
1891 let action = match build_section_runtime_action(
1892 value,
1893 section_idx,
1894 &format!("section_{}", name),
1895 ) {
1896 Ok(action) => action,
1897 Err(err) => return err,
1898 };
1899 named_actions.push((name, action));
1900 }
1901
1902 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
1904 return syn::Error::new(names[0].span(), msg)
1905 .to_compile_error()
1906 .into();
1907 }
1908
1909 for (name, action) in named_actions {
1910 runtime_section_actions.insert(name, action);
1911 }
1912 }
1913
1914 let sql_section_names: std::collections::HashSet<&str> = segments
1915 .iter()
1916 .filter_map(|seg| {
1917 if let Segment::Section { name } = seg {
1918 Some(name.as_str())
1919 } else {
1920 None
1921 }
1922 })
1923 .collect();
1924 for name in runtime_section_actions.keys() {
1925 if !sql_section_names.contains(name.as_str()) {
1926 return syn::Error::new(
1927 sql_span,
1928 format!(
1929 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
1930 name, name,
1931 ),
1932 )
1933 .to_compile_error()
1934 .into();
1935 }
1936 }
1937
1938 let mut generated_query_defs = Vec::<TokenStream2>::new();
1940 let mut generated_query_values = Vec::<TokenStream2>::new();
1941 let mut group_field_defs = Vec::<TokenStream2>::new();
1942 let mut group_method_defs = Vec::<TokenStream2>::new();
1943 let mut group_field_idents = Vec::<syn::Ident>::new();
1944 let mut group_field_tys = Vec::<TokenStream2>::new();
1945 let mut group_trait_impls = Vec::<TokenStream2>::new();
1946
1947 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
1948
1949 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
1950 let suffix = result_key.as_deref().unwrap_or("single");
1951 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
1952 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
1953
1954 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
1955
1956 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
1957 for assign in §ions {
1958 let SectionAssign { names, value } = assign;
1959 let variants_by_section = match collect_section_variants_for_result(
1960 value.clone(),
1961 names.len(),
1962 result_key.as_deref(),
1963 ) {
1964 Ok(v) => v,
1965 Err(msg) => {
1966 return syn::Error::new(names[0].span(), msg)
1967 .to_compile_error()
1968 .into();
1969 }
1970 };
1971
1972 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
1973 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
1974 }
1975 }
1976
1977 let mut nmax = 1usize;
1978 for segment in &segments {
1979 if let Segment::Section { name } = segment {
1980 if let Some(variants) = section_variants_for_validation.get(name) {
1981 if variants.is_empty() {
1982 return syn::Error::new(
1983 sql_span,
1984 format!("sql_forge!: section {{#{}}} has no possible variants", name),
1985 )
1986 .to_compile_error()
1987 .into();
1988 }
1989 nmax = nmax.max(variants.len());
1990 } else {
1991 return syn::Error::new(
1992 sql_span,
1993 format!("sql_forge!: section {{#{}}} has no mapping", name),
1994 )
1995 .to_compile_error()
1996 .into();
1997 }
1998 }
1999 }
2000
2001 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2002 for case_idx in 0..nmax {
2003 let mut sql_case = String::new();
2004 let mut case_setup = Vec::<TokenStream2>::new();
2005 let mut case_args = Vec::<TokenStream2>::new();
2006 let mut param_offset = 0usize;
2007 let empty_params = HashMap::<String, syn::Ident>::new();
2008 let root_validator_context = ValidatorRenderContext {
2009 local_params: &empty_params,
2010 top_level_params: &declared_params,
2011 allow_top_level_fallback: true,
2012 use_dollar_params,
2013 sql_span,
2014 list_count,
2015 };
2016
2017 for segment in &segments {
2018 match segment {
2019 Segment::Text(text) => {
2020 let (chunk_sql, chunk_args) = match render_validator_args(
2021 text,
2022 &mut param_offset,
2023 &root_validator_context,
2024 ) {
2025 Ok(value) => value,
2026 Err(err) => return err,
2027 };
2028 sql_case.push_str(&chunk_sql);
2029 case_args.extend(chunk_args);
2030 }
2031 Segment::Section { name } => {
2032 let Some(variants) = section_variants_for_validation.get(name) else {
2033 return syn::Error::new(
2034 sql_span,
2035 format!("sql_forge!: section {{#{}}} has no mapping", name),
2036 )
2037 .to_compile_error()
2038 .into();
2039 };
2040
2041 let fragment = &variants[case_idx % variants.len()];
2042 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2043 let (local_params, bindings) = match build_param_bindings(
2044 &fragment.params,
2045 &used_param_names,
2046 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2047 true,
2048 true,
2049 ) {
2050 Ok(value) => value,
2051 Err(err) => return err,
2052 };
2053 let section_validator_context = ValidatorRenderContext {
2054 local_params: &local_params,
2055 top_level_params: &declared_params,
2056 allow_top_level_fallback: false,
2057 use_dollar_params,
2058 sql_span: fragment.span,
2059 list_count,
2060 };
2061 let (chunk_sql, chunk_args) = match render_validator_args(
2062 &fragment.sql,
2063 &mut param_offset,
2064 §ion_validator_context,
2065 ) {
2066 Ok(value) => value,
2067 Err(err) => return err,
2068 };
2069 sql_case.push_str(&chunk_sql);
2070 case_setup.extend(bindings);
2071 case_args.extend(chunk_args);
2072 }
2073 Segment::Batch { parts } => {
2074 let mut first = true;
2075 for _ in 0..list_count {
2076 let sep = if first { "" } else { ", " };
2077 first = false;
2078 sql_case.push_str(sep);
2079 for tp in parts {
2080 match tp {
2081 TextPart::Lit(lit) => sql_case.push_str(lit),
2082 TextPart::Param { name, .. } => {
2083 if let Some(batch_expr) = &batch {
2084 let field_ident = format_ident!("{}", name);
2085 if use_dollar_params {
2086 param_offset += 1;
2087 write!(sql_case, "${}", param_offset).unwrap();
2088 } else {
2089 sql_case.push('?');
2090 }
2091 case_args.push(quote! { #batch_expr[0].#field_ident });
2092 } else if use_dollar_params {
2093 param_offset += 1;
2094 write!(sql_case, "${}", param_offset).unwrap();
2095 } else {
2096 sql_case.push('?');
2097 }
2098 }
2099 }
2100 }
2101 }
2102 }
2103 }
2104 }
2105
2106 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2107 }
2108
2109 let mut validator_invocations = Vec::<TokenStream2>::new();
2110 for (sql_lit, case_setup, args) in &validator_cases {
2111 if model_opt.is_none() {
2112 if args.is_empty() {
2113 validator_invocations.push(quote! {
2114 {
2115 #( #case_setup )*
2116 let _ = sqlx::query_scalar!(
2117 #sql_lit,
2118 );
2119 }
2120 });
2121 } else {
2122 validator_invocations.push(quote! {
2123 {
2124 #( #case_setup )*
2125 let _ = sqlx::query_scalar!(
2126 #sql_lit,
2127 #( #args ),*
2128 );
2129 }
2130 });
2131 }
2132 } else if let Some(scalar_ty) = scalar_model_ty {
2133 if args.is_empty() {
2134 validator_invocations.push(quote! {
2135 {
2136 #( #case_setup )*
2137 let _ = sqlx::query_scalar!(
2138 #sql_lit,
2139 );
2140 }
2141 });
2142 } else {
2143 validator_invocations.push(quote! {
2144 {
2145 #( #case_setup )*
2146 let _ = sqlx::query_scalar!(
2147 #sql_lit,
2148 #( #args ),*
2149 );
2150 }
2151 });
2152 }
2153 let _ = scalar_ty;
2154 } else if args.is_empty() {
2155 validator_invocations.push(quote! {
2156 {
2157 #( #case_setup )*
2158 let _ = sqlx::query_as!(
2159 __EnhancedModel,
2160 #sql_lit,
2161 );
2162 }
2163 });
2164 } else {
2165 validator_invocations.push(quote! {
2166 {
2167 #( #case_setup )*
2168 let _ = sqlx::query_as!(
2169 __EnhancedModel,
2170 #sql_lit,
2171 #( #args ),*
2172 );
2173 }
2174 });
2175 }
2176 }
2177
2178 let model_alias = if let Some(model) = model_opt {
2179 if scalar_model_ty.is_none() {
2180 quote! { type __EnhancedModel = #model; }
2181 } else {
2182 quote! {}
2183 }
2184 } else {
2185 quote! {}
2186 };
2187 grouped_validator_invocations.push(quote! {
2188 {
2189 #( #flag_bindings )*
2190 #model_alias
2191 #( #validator_invocations )*
2192 }
2193 });
2194
2195 let (runtime_declared_params, runtime_param_bindings) =
2196 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2197 Ok(v) => v,
2198 Err(err) => return err,
2199 };
2200
2201 let mut runtime_steps = Vec::<TokenStream2>::new();
2202 for (seg_idx, segment) in segments.iter().enumerate() {
2203 match segment {
2204 Segment::Text(text) => {
2205 for part in parse_text_parts(text) {
2206 match part {
2207 TextPart::Lit(lit) => {
2208 let lit = sanitize_runtime_sql_text(&lit);
2209 let lit_str = LitStr::new(&lit, sql_span);
2210 runtime_steps.push(quote! {
2211 __builder.push(#lit_str);
2212 });
2213 }
2214 TextPart::Param { name, is_list } => {
2215 let Some(local_ident) = runtime_declared_params.get(&name) else {
2216 return syn::Error::new(
2217 sql_span,
2218 format!("sql_forge!: parameter :{} has no mapping", name),
2219 )
2220 .to_compile_error()
2221 .into();
2222 };
2223
2224 if is_list {
2225 runtime_steps.push(quote! {
2226 let __enhanced_values = #local_ident;
2227 let mut __separated = __builder.separated(", ");
2228 for __value in __enhanced_values {
2229 __separated.push_bind(__value);
2230 }
2231 });
2232 } else {
2233 runtime_steps.push(quote! {
2234 __builder.push_bind(#local_ident);
2235 });
2236 }
2237 }
2238 }
2239 }
2240 }
2241 Segment::Section { name } => {
2242 let Some(section_action) = runtime_section_actions.get(name) else {
2243 let _ = seg_idx;
2244 return syn::Error::new(
2245 sql_span,
2246 format!("sql_forge!: section {{#{}}} has no mapping", name),
2247 )
2248 .to_compile_error()
2249 .into();
2250 };
2251 runtime_steps.push(quote! {
2252 #section_action
2253 });
2254 }
2255 Segment::Batch { parts } => {
2256 if let Some(batch_expr) = &batch {
2257 let mut body = Vec::<TokenStream2>::new();
2258 for part in parts {
2259 match part {
2260 TextPart::Lit(lit) => {
2261 let lit_str = LitStr::new(lit, sql_span);
2262 body.push(quote! {
2263 __builder.push(#lit_str);
2264 });
2265 }
2266 TextPart::Param { name, .. } => {
2267 let field_ident = format_ident!("{}", name);
2268 body.push(quote! {
2269 __builder.push_bind(__item.#field_ident);
2270 });
2271 }
2272 }
2273 }
2274 runtime_steps.push(quote! {
2275 {
2276 let mut __first = true;
2277 for __item in #batch_expr {
2278 if !__first {
2279 __builder.push(", ");
2280 }
2281 __first = false;
2282 #( #body )*
2283 }
2284 }
2285 });
2286 }
2287 }
2288 }
2289 }
2290
2291 let exec_methods = if model_opt.is_none() {
2292 quote! {
2293 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2294 where
2295 E: sqlx::Executor<'e, Database = #db>,
2296 {
2297 self.inner.build().execute(executor).await
2298 }
2299 }
2300 } else if let Some(scalar_ty) = scalar_model_ty {
2301 quote! {
2302 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2303 where
2304 E: sqlx::Executor<'e, Database = #db>,
2305 {
2306 self.inner
2307 .build_query_scalar::<#scalar_ty>()
2308 .fetch_all(executor)
2309 .await
2310 }
2311
2312 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2313 where
2314 E: sqlx::Executor<'e, Database = #db>,
2315 {
2316 self.inner
2317 .build_query_scalar::<#scalar_ty>()
2318 .fetch_one(executor)
2319 .await
2320 }
2321
2322 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2323 where
2324 E: sqlx::Executor<'e, Database = #db>,
2325 {
2326 self.inner
2327 .build_query_scalar::<#scalar_ty>()
2328 .fetch_optional(executor)
2329 .await
2330 }
2331
2332 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2333 where
2334 E: sqlx::Executor<'e, Database = #db>,
2335 {
2336 self.inner.build().execute(executor).await
2337 }
2338 }
2339 } else {
2340 let model = model_opt.as_ref().unwrap();
2341 quote! {
2342 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2343 where
2344 E: sqlx::Executor<'e, Database = #db>,
2345 {
2346 self.inner.build_query_as::<#model>().fetch_all(executor).await
2347 }
2348
2349 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2350 where
2351 E: sqlx::Executor<'e, Database = #db>,
2352 {
2353 self.inner.build_query_as::<#model>().fetch_one(executor).await
2354 }
2355
2356 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2357 where
2358 E: sqlx::Executor<'e, Database = #db>,
2359 {
2360 self.inner
2361 .build_query_as::<#model>()
2362 .fetch_optional(executor)
2363 .await
2364 }
2365
2366 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2367 where
2368 E: sqlx::Executor<'e, Database = #db>,
2369 {
2370 self.inner.build().execute(executor).await
2371 }
2372 }
2373 };
2374
2375 let final_type: TokenStream2 = if let Some(model) = model_opt {
2376 if let Some(scalar_ty) = scalar_model_ty {
2377 quote! { #scalar_ty }
2378 } else {
2379 quote! { #model }
2380 }
2381 } else {
2382 quote! {}
2383 };
2384 let trait_impl = if model_opt.is_none() {
2385 quote! {
2386 impl<'args> sql_forge::SqlForgeQueryExecute
2387 for #query_ident<'args>
2388 {
2389 type Db = #db;
2390
2391 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2392 where
2393 Self: Sized + 'e,
2394 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2395 #db: 'e,
2396 {
2397 #query_ident::execute(self, executor)
2398 }
2399 }
2400 }
2401 } else {
2402 quote! {
2403 impl<'args> sql_forge::SqlForgeQuery<#final_type>
2404 for #query_ident<'args>
2405 {
2406 type Db = #db;
2407
2408 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2409 where
2410 Self: Sized + 'e,
2411 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2412 #db: 'e,
2413 {
2414 #query_ident::fetch_all(self, executor)
2415 }
2416
2417 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2418 where
2419 Self: Sized + 'e,
2420 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2421 #db: 'e,
2422 {
2423 #query_ident::fetch_one(self, executor)
2424 }
2425
2426 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2427 where
2428 Self: Sized + 'e,
2429 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2430 #db: 'e,
2431 {
2432 #query_ident::fetch_optional(self, executor)
2433 }
2434
2435 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2436 where
2437 Self: Sized + 'e,
2438 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2439 #db: 'e,
2440 {
2441 #query_ident::execute(self, executor)
2442 }
2443 }
2444 }
2445 };
2446
2447 generated_query_defs.push(quote! {
2448 struct #query_ident<'args> {
2449 inner: sqlx::QueryBuilder<'args, #db>,
2450 }
2451
2452 impl<'args> #query_ident<'args> {
2453 #exec_methods
2454 }
2455
2456 #trait_impl
2457 });
2458
2459 generated_query_values.push(quote! {
2460 #( #runtime_param_bindings )*
2461 #( #flag_bindings )*
2462 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2463 #( #runtime_steps )*
2464 let #query_value_ident = #query_ident { inner: __builder };
2465 });
2466
2467 if let Some(key) = result_key {
2468 let method_ident = format_ident!("{}", key);
2469 group_field_defs.push(quote! {
2470 #method_ident: #query_ident<'args>
2471 });
2472 group_field_tys.push(quote! { #query_ident<'args> });
2473 group_method_defs.push(quote! {
2474 pub fn #method_ident(self) -> #query_ident<'args> {
2475 self.#method_ident
2476 }
2477 });
2478
2479 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2480 group_trait_impls.push(quote! {
2481 struct #key_ty_ident;
2482
2483 impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2484 type Query = #query_ident<'args>;
2485
2486 fn get(self, _: #key_ty_ident) -> Self::Query {
2487 self.#method_ident
2488 }
2489 }
2490 });
2491 group_field_idents.push(method_ident);
2492 }
2493 }
2494
2495 let validator_tokens = quote! {
2497 let _sql_forge_validator = || {
2498 #( #validator_param_bindings )*
2499 #( #grouped_validator_invocations )*
2500 };
2501 };
2502
2503 if !is_grouped_result {
2504 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2505 return quote! {
2506 {
2507 #validator_tokens
2508 #( #generated_query_defs )*
2509 #( #generated_query_values )*
2510 #single_query_value_ident
2511 }
2512 }
2513 .into();
2514 }
2515
2516 let group_field_inits: Vec<TokenStream2> = result_cases
2517 .iter()
2518 .filter_map(|(key, _, _)| key.as_ref())
2519 .map(|key| {
2520 let method_ident = format_ident!("{}", key);
2521 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2522 quote! { #method_ident: #query_value_ident }
2523 })
2524 .collect();
2525
2526 quote! {
2527 {
2528 #validator_tokens
2529
2530 #( #generated_query_defs )*
2531 #( #generated_query_values )*
2532
2533 struct __SqlForgeQueryGroup<'args> {
2534 #( #group_field_defs, )*
2535 }
2536
2537 impl<'args> __SqlForgeQueryGroup<'args> {
2538 #( #group_method_defs )*
2539
2540 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2541 ( #( self.#group_field_idents ),* )
2542 }
2543 }
2544
2545 impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2546 type Db = #db;
2547 }
2548
2549 #( #group_trait_impls )*
2550
2551 __SqlForgeQueryGroup {
2552 #( #group_field_inits, )*
2553 }
2554 }
2555 }
2556 .into()
2557}
2558
2559#[proc_macro]
2573pub fn db_type(input: TokenStream) -> TokenStream {
2574 if !input.is_empty() {
2575 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2576 .to_compile_error()
2577 .into();
2578 }
2579
2580 match resolve_db_from_env() {
2581 Ok(db) => quote! { #db }.into(),
2582 Err(msg) => syn::Error::new(Span::call_site(), msg)
2583 .to_compile_error()
2584 .into(),
2585 }
2586}