1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17 Stmt, Token, Type,
18};
19
/// Custom keywords recognized by the `sql_forge!` argument grammar.
mod kw {
    // `scalar` is written in front of a result type; the parsers in this file
    // record its presence as a `force_scalar` flag.
    syn::custom_keyword!(scalar);
}
27
/// A single `:name = expr` entry of a parameter map.
#[derive(Clone)]
struct ParamAssign {
    /// Identifier written after the leading `:`.
    name: Ident,
    /// Expression written after the `=`.
    expr: Expr,
}
34
35impl Parse for ParamAssign {
36 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37 input.parse::<Token![:]>()?;
38 let name: Ident = input.parse()?;
39 input.parse::<Token![=]>()?;
40 let expr: Expr = input.parse()?;
41 Ok(Self { name, expr })
42 }
43}
44
/// One SQL fragment supplied for a section, together with its optional
/// section-local parameter source.
#[derive(Clone)]
struct SectionFragment {
    /// Value of the fragment's string literal.
    sql: String,
    /// Span of the expression the SQL literal was extracted from.
    span: Span,
    /// Section-local parameters; `ParamsSource::None` when only a literal was given.
    params: ParamsSource,
}
52
/// One arm of a `match` (or desugared `if` / `if let`) used as a section value.
#[derive(Clone)]
struct SectionMatchArm {
    /// Pattern of the arm.
    pat: Pat,
    /// Optional `if` guard written after the pattern.
    guard: Option<Expr>,
    /// Section value produced when the arm is selected.
    value: SectionValue,
}
60
/// The right-hand side of a `#name = ...` section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single SQL fragment.
    Single(SectionFragment),
    /// A parenthesized list of values, one per name of a grouped assignment.
    Grouped(Vec<SectionValue>),
    /// A `match` over `expr`; `if` / `if let` forms are parsed into this
    /// variant as well (as a two-arm match ending in `_`).
    Match {
        /// Scrutinee expression.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
/// A `#name = value` or `#(a, b, ...) = value` entry of a section map.
#[derive(Clone)]
struct SectionAssign {
    /// Section names being assigned; more than one for the grouped form.
    names: Vec<Ident>,
    /// The assigned value, parsed with a width equal to `names.len()`.
    value: SectionValue,
}
81
82impl Parse for SectionAssign {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 input.parse::<Token![#]>()?;
85
86 let names = if input.peek(syn::token::Paren) {
88 let content;
89 syn::parenthesized!(content in input);
90 let mut out = Vec::new();
91 while !content.is_empty() {
92 out.push(content.parse::<Ident>()?);
93 if content.is_empty() {
94 break;
95 }
96 content.parse::<Token![,]>()?;
97 }
98 if out.is_empty() {
99 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100 }
101 out
102 } else {
103 vec![input.parse::<Ident>()?]
104 };
105
106 input.parse::<Token![=]>()?;
107 let value = parse_section_value(input, names.len())?;
108 Ok(Self { names, value })
109 }
110}
111
/// Fully parsed arguments of one `sql_forge!` invocation.
struct SqlForgeInput {
    /// Database type given as a leading macro argument, if any.
    db: Option<Type>,
    /// Requested result shape (none, a single model, or a named group).
    result: ResultSpec,
    /// Set when the single result type was prefixed with the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template literal.
    sql: SqlTemplate,
    /// Top-level parameter source (map or expression), if one was given.
    params: ParamsSource,
    /// Entries of the `(#name = ..., ...)` section map, if one was given.
    sections: Vec<SectionAssign>,
    /// Expression written after `..`, if a batch source was given.
    batch: Option<Expr>,
}
122
/// A single `>name = Model` (or `>name = scalar Model`) entry of a result map.
#[derive(Clone)]
struct ResultAssign {
    /// Identifier written after the leading `>`.
    name: Ident,
    /// Type written after the `=` (and after `scalar`, when present).
    model: Type,
    /// Whether the `scalar` keyword preceded the type.
    force_scalar: bool,
}
130
/// What kind of result the macro invocation asked for.
#[derive(Clone)]
enum ResultSpec {
    /// No result type was given.
    None,
    /// A single result type.
    Single(Box<Type>),
    /// A named group of result types from a `(>name = Model, ...)` map.
    Group(Vec<ResultAssign>),
}
140
/// Where the values for `:name` parameters come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was given.
    None,
    /// An explicit `(:name = expr, ...)` map.
    Map(Vec<ParamAssign>),
    /// An arbitrary expression.
    // NOTE(review): presumably parameters are read as fields of this value;
    // the consuming code is outside this view — confirm.
    Struct(Box<Expr>),
}
149
/// The SQL template argument of the macro.
enum SqlTemplate {
    /// A string literal containing the template text.
    Literal(LitStr),
}
154
155impl SqlTemplate {
156 fn span(&self) -> Span {
157 match self {
158 Self::Literal(lit) => lit.span(),
159 }
160 }
161
162 fn into_segments(self) -> Result<Vec<Segment>, String> {
164 match self {
165 Self::Literal(lit) => parse_literal_segments(&lit.value()),
166 }
167 }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171 if input.peek(LitStr) {
172 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173 } else {
174 Err(input.error("sql_forge!: SQL template must be a string literal"))
175 }
176}
177
/// One piece of a SQL template after splitting it at `{...}` placeholders.
#[derive(Clone)]
enum Segment {
    /// Literal template text.
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, already split into literal and parameter parts.
    Batch { parts: Vec<TextPart> },
}
188
/// One piece of SQL text after splitting it at `:name` parameters.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:name` parameter; `is_list` is set for the `:name[]` form.
    Param { name: String, is_list: bool },
}
197
/// Kind of a parenthesized map argument, decided by its first token.
enum MapKind {
    /// Starts with `>`.
    Results,
    /// Starts with `:`.
    Params,
    /// Starts with `#`.
    Sections,
}
208
209fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
215 let fork = input.fork();
216 let content;
217 syn::parenthesized!(content in fork);
218
219 if content.is_empty() {
220 return Err(input.error("sql_forge!: map argument cannot be empty"));
221 }
222
223 if content.peek(Token![>]) {
224 Ok(Some(MapKind::Results))
225 } else if content.peek(Token![:]) {
226 Ok(Some(MapKind::Params))
227 } else if content.peek(Token![#]) {
228 Ok(Some(MapKind::Sections))
229 } else {
230 Ok(None)
231 }
232}
233
234impl Parse for ResultAssign {
235 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
236 input.parse::<Token![>]>()?;
237 let name: Ident = input.parse()?;
238 input.parse::<Token![=]>()?;
239 let (force_scalar, model) = if input.peek(kw::scalar) {
240 input.parse::<kw::scalar>()?;
241 (true, input.parse::<Type>()?)
242 } else {
243 (false, input.parse::<Type>()?)
244 };
245 Ok(Self {
246 name,
247 model,
248 force_scalar,
249 })
250 }
251}
252
253fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
255 let content;
256 syn::parenthesized!(content in input);
257
258 let mut results = Vec::new();
259 while !content.is_empty() {
260 results.push(content.parse::<ResultAssign>()?);
261 if content.is_empty() {
262 break;
263 }
264 content.parse::<Token![,]>()?;
265 }
266
267 if results.is_empty() {
268 return Err(input.error("sql_forge!: result map cannot be empty"));
269 }
270
271 Ok(results)
272}
273
274fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
276 let content;
277 syn::parenthesized!(content in input);
278
279 let mut params = Vec::new();
280 while !content.is_empty() {
281 params.push(content.parse::<ParamAssign>()?);
282 if content.is_empty() {
283 break;
284 }
285 content.parse::<Token![,]>()?;
286 }
287
288 Ok(params)
289}
290
291fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
293 let content;
294 syn::parenthesized!(content in input);
295
296 let mut sections = Vec::new();
297 while !content.is_empty() {
298 sections.push(content.parse::<SectionAssign>()?);
299 if content.is_empty() {
300 break;
301 }
302 content.parse::<Token![,]>()?;
303 }
304
305 Ok(sections)
306}
307
308fn parse_params_source_expr(input: ParseStream<'_>) -> syn::Result<ParamsSource> {
310 if input.peek(syn::token::Paren) {
311 match detect_parenthesized_map_kind(input)? {
312 Some(MapKind::Results) => Err(input
313 .error("sql_forge!: result maps are only allowed as the macro result argument")),
314 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
315 Some(MapKind::Sections) => Err(input.error(
316 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317 )),
318 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319 }
320 } else {
321 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322 }
323}
324
/// Parses one section fragment: either a bare SQL string literal or a
/// `("sql", params)` tuple carrying a section-local parameter source.
///
/// Returns an error when the value is neither form.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Probe the tuple form on a fork, so `input` stays untouched unless the
        // complete `(literal, params[,])` shape is present.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Errors from the parameter source are reported even during the probe.
                let _ = parse_params_source_expr(&content)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // The probe matched: parse the same tokens again, this time
                    // consuming them from `input`.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Not a tuple: the value must be an expression that reduces to a string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
373
/// Parses the value of a section assignment.
///
/// `width` is the number of section names being assigned. Accepted forms:
/// * `match expr { pat [if guard] => value, ... }`
/// * `if cond { value } else { value }` and `if let pat = expr { ... } else { ... }`
///   (with `else if` chains), both turned into a two-arm `SectionValue::Match`
/// * a single fragment when `width == 1`
/// * otherwise a parenthesized list of exactly `width` single-width values
fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
    if input.peek(Token![match]) {
        input.parse::<Token![match]>()?;
        // Parse the scrutinee without letting it swallow the `{ ... }` arm block.
        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
        let content;
        syn::braced!(content in input);
        let mut arms = Vec::new();
        while !content.is_empty() {
            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
            let guard = if content.peek(Token![if]) {
                content.parse::<Token![if]>()?;
                Some(content.parse::<Expr>()?)
            } else {
                None
            };
            content.parse::<Token![=>]>()?;
            // Each arm yields a value of the same width as the whole match.
            let value = parse_section_value(&content, width)?;
            // The comma after an arm is optional.
            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
            arms.push(SectionMatchArm { pat, guard, value });
        }
        return Ok(SectionValue::Match { expr, arms });
    }

    if input.peek(Token![if]) {
        input.parse::<Token![if]>()?;

        let (pat, expr) = if input.peek(Token![let]) {
            input.parse::<Token![let]>()?;
            let pat: Pat = input.call(Pat::parse_single)?;
            input.parse::<Token![=]>()?;
            let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
            (pat, expr)
        } else {
            let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
            // A plain `if cond` becomes a match on `cond` with the pattern `true`.
            let pat: Pat = parse_quote! { true };
            (pat, expr)
        };

        let true_content;
        syn::braced!(true_content in input);
        let true_value = parse_section_value(&true_content, width)?;
        // An `else` branch is mandatory.
        input.parse::<Token![else]>()?;

        let false_value = if input.peek(Token![if]) {
            // `else if ...` recurses on the same stream.
            parse_section_value(input, width)?
        } else {
            let false_content;
            syn::braced!(false_content in input);
            parse_section_value(&false_content, width)?
        };

        let wild_pat: Pat = parse_quote! { _ };

        let arms = vec![
            SectionMatchArm {
                pat,
                guard: None,
                value: true_value,
            },
            SectionMatchArm {
                pat: wild_pat,
                guard: None,
                value: false_value,
            },
        ];
        return Ok(SectionValue::Match { expr, arms });
    }

    if width == 1 {
        return Ok(SectionValue::Single(parse_section_fragment(input)?));
    }

    // Grouped form: `(value, value, ...)` with an optional trailing comma.
    let content;
    syn::parenthesized!(content in input);
    let mut items = Vec::new();
    while !content.is_empty() {
        items.push(parse_section_value(&content, 1)?);
        if content.is_empty() {
            break;
        }
        content.parse::<Token![,]>()?;
    }

    if items.len() != width {
        return Err(input.error(format!(
            "sql_forge!: grouped section value must provide exactly {} items",
            width,
        )));
    }

    Ok(SectionValue::Grouped(items))
}
474
impl Parse for SqlForgeInput {
    /// Parses the complete argument list of a `sql_forge!` invocation.
    ///
    /// Leading arguments accepted before the SQL literal:
    /// * nothing — the SQL literal comes first
    /// * `scalar Model,`
    /// * `(>name = Model, ...),`
    /// * `Type,` — treated as the DB type when `is_db_type` accepts it and the
    ///   SQL literal follows directly, otherwise as the result model
    /// * `Db, scalar Model,` / `Db, (>name = Model, ...),` / `Db, Model,`
    ///
    /// After the SQL literal, a comma-separated list may supply at most one
    /// batch source (`..expr`), one parameter source (a `(:name = expr, ...)`
    /// map or any expression) and one section map (`(#name = ..., ...)`).
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // A leading parenthesized argument is only valid as a result map.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // The first argument is a type; whether it is the DB or the result
            // model depends on what follows it.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) && is_db_type(&first_ty) {
                let db = first_ty;
                let sql = parse_sql_template(input)?;
                (Some(db), ResultSpec::None, false, sql)
            } else if input.peek(LitStr) {
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Optional trailing arguments after the SQL literal.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        // Parenthesized but not a map: an ordinary expression
                        // used as the parameter source.
                        None => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Continue only while arguments are comma-separated.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
628
629fn resolve_db_from_env() -> Result<Type, String> {
635 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
636 return syn::parse_str::<Type>(&val).map_err(|err| {
637 format!(
638 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
639 val, err
640 )
641 });
642 }
643
644 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
645 Ok(d) => d,
646 Err(_) => {
647 return Err(
648 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
649 or configure [package.metadata.sql_forge] in Cargo.toml"
650 .to_string(),
651 );
652 }
653 };
654 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
655
656 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
657 format!(
658 "sql_forge!: failed to read {}: {}",
659 manifest_path.display(),
660 err
661 )
662 })?;
663
664 let value: toml::Value = toml::from_str(&cargo_toml)
665 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
666
667 let db_str = value
668 .get("package")
669 .and_then(|v| v.get("metadata"))
670 .and_then(|v| v.get("sql_forge"))
671 .and_then(|v| v.get("db"))
672 .and_then(|v| v.as_str())
673 .ok_or({
674 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
675 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
676 })?;
677
678 syn::parse_str::<Type>(db_str).map_err(|err| {
679 format!(
680 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
681 db_str, err
682 )
683 })
684}
685
686fn uses_dollar_params(db: &Type) -> bool {
688 let Type::Path(type_path) = db else {
689 return false;
690 };
691 type_path
692 .path
693 .segments
694 .last()
695 .is_some_and(|s| s.ident == "Postgres")
696}
697
698fn is_db_type(ty: &Type) -> bool {
700 let Type::Path(type_path) = ty else {
701 return false;
702 };
703 if type_path.qself.is_some() {
704 return false;
705 }
706 let segs = &type_path.path.segments;
707 if segs.len() != 2 {
708 return false;
709 }
710 segs[0].ident == "sqlx"
711 && ["MySql", "Postgres", "Sqlite"].contains(&segs[1].ident.to_string().as_str())
712}
713
714fn is_builtin_scalar_type(ty: &Type) -> bool {
716 let Type::Path(type_path) = ty else {
717 return false;
718 };
719
720 if type_path.qself.is_some()
721 || type_path.path.leading_colon.is_some()
722 || type_path.path.segments.len() != 1
723 {
724 return false;
725 }
726
727 let ident = &type_path.path.segments[0].ident;
728 ident == "i8"
729 || ident == "i16"
730 || ident == "i32"
731 || ident == "i64"
732 || ident == "isize"
733 || ident == "u8"
734 || ident == "u16"
735 || ident == "u32"
736 || ident == "u64"
737 || ident == "usize"
738 || ident == "f32"
739 || ident == "f64"
740 || ident == "bool"
741 || ident == "String"
742}
743
744fn scalar_output_type(model: &Type) -> Option<&Type> {
746 if is_builtin_scalar_type(model) {
747 return Some(model);
748 }
749 None
750}
751
752fn push_text_segment(out: &mut Vec<Segment>, text: String) {
754 if text.is_empty() {
755 return;
756 }
757 match out.last_mut() {
758 Some(Segment::Text(existing)) => existing.push_str(&text),
759 _ => out.push(Segment::Text(text)),
760 }
761}
762
/// Splits a SQL template into text, `{#name}` section placeholders and
/// `{( ... )}` batch sections.
///
/// A `{` that is followed by neither `(` nor `#` is kept as ordinary text.
///
/// # Errors
/// Returns a message for an unterminated placeholder, an empty section name,
/// a `{` inside a batch section, unbalanced parentheses inside a batch
/// section, or a `:name[]` list parameter inside a batch section.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    // Text accumulated since the last placeholder.
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        // `{(` opens a batch section that runs up to the next `}`.
        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            // Everything between `{` and `}`, including the outer parentheses.
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    // The closing brace is only valid once every `(` has been closed.
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // Any other `{` that does not start `{#` is plain text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Consume the `#` and read the section name up to the closing `}`.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
856
/// Returns `true` for characters that may start a parameter name:
/// `_` or an ASCII letter.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, '_' | 'a'..='z' | 'A'..='Z')
}
865
866fn is_ident_continue(ch: char) -> bool {
868 is_ident_start(ch) || ch.is_ascii_digit()
869}
870
/// Strips a trailing annotation from the content of a backticked identifier.
///
/// Everything from the first `!`, `?` or `:` onward is dropped and trailing
/// whitespace before it is trimmed. The content is returned unchanged when it
/// has no such marker or when nothing would remain in front of it.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker_at) = content.find(|ch: char| matches!(ch, '!' | '?' | ':')) else {
        return content.to_string();
    };

    let base = content[..marker_at].trim_end();
    let kept = if base.is_empty() { content } else { base };
    kept.to_string()
}
892
893fn sanitize_runtime_sql_text(text: &str) -> String {
895 let mut out = String::with_capacity(text.len());
896 let mut chars = text.chars().peekable();
897
898 while let Some(ch) = chars.next() {
899 if ch != '`' {
900 out.push(ch);
901 continue;
902 }
903
904 let mut content = String::new();
905 let mut closed = false;
906
907 for next in chars.by_ref() {
908 if next == '`' {
909 closed = true;
910 break;
911 }
912 content.push(next);
913 }
914
915 if closed {
916 out.push('`');
917 out.push_str(&sanitize_backticked_alias_ident(&content));
918 out.push('`');
919 } else {
920 out.push('`');
921 out.push_str(&content);
922 break;
923 }
924 }
925
926 out
927}
928
/// Splits SQL text into literal pieces and `:name` / `:name[]` parameters.
///
/// A `:` starts a parameter only when it is followed by an identifier-start
/// character and is not itself preceded by another `:` (so `::name` is left
/// as literal text).
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset of the first character not yet emitted.
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // Second colon of a `::` pair: not a parameter.
        if text[..idx].ends_with(':') {
            continue;
        }

        // Flush the literal text in front of the parameter.
        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        // Consume the first character of the name (already peeked).
        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        // Byte offset just past the end of the parameter.
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        let mut is_list = false;
        if text[end..].starts_with("[]") {
            is_list = true;
            // The iterator is not advanced here; the `[` and `]` it yields
            // next are skipped by the `ch != ':'` check above.
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
988
989#[allow(clippy::type_complexity)]
1000fn render_validator_sql(
1001 parts: &[TextPart],
1002 use_dollar_params: bool,
1003 param_offset: &mut usize,
1004 list_count: usize,
1005 batch_expr: Option<TokenStream2>,
1006) -> Result<(String, Vec<(String, bool)>, Vec<TokenStream2>), TokenStream> {
1007 let mut out_sql = String::new();
1008 let mut occurrences = Vec::new();
1009 let mut batch_args = Vec::new();
1010
1011 for part in parts {
1012 match part {
1013 TextPart::Lit(lit) => out_sql.push_str(lit),
1014 TextPart::Param { name, is_list } => {
1015 if let Some(ref batch_expr) = batch_expr {
1016 if *is_list {
1017 return Err(syn::Error::new(
1018 Span::call_site(),
1019 "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
1020 batch sections; use plain parameters (:name) instead",
1021 )
1022 .to_compile_error()
1023 .into());
1024 }
1025 let field_ident = format_ident!("{}", name);
1026 if use_dollar_params {
1027 *param_offset += 1;
1028 write!(out_sql, "${}", *param_offset).unwrap();
1029 } else {
1030 out_sql.push('?');
1031 }
1032 batch_args.push(quote! { #batch_expr[0].#field_ident });
1033 } else if *is_list && list_count > 1 {
1034 let slots: Vec<String> = if use_dollar_params {
1035 (0..list_count)
1036 .map(|i| format!("${}", *param_offset + i + 1))
1037 .collect()
1038 } else {
1039 (0..list_count).map(|_| "?".to_string()).collect()
1040 };
1041 if use_dollar_params {
1042 *param_offset += list_count;
1043 }
1044 out_sql.push_str(&slots.join(", "));
1045 occurrences.push((name.clone(), *is_list));
1046 } else {
1047 if use_dollar_params {
1048 *param_offset += 1;
1049 write!(out_sql, "${}", *param_offset).unwrap();
1050 } else {
1051 out_sql.push('?');
1052 }
1053 occurrences.push((name.clone(), *is_list));
1054 }
1055 }
1056 }
1057 }
1058
1059 Ok((out_sql, occurrences, batch_args))
1060}
1061
1062fn strip_expr(expr: &Expr) -> &Expr {
1085 match expr {
1086 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
1087 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
1088 Expr::Block(ExprBlock { block, .. }) => {
1089 if block.stmts.len() != 1 {
1090 return expr;
1091 }
1092 match &block.stmts[0] {
1093 Stmt::Expr(inner, None) => strip_expr(inner),
1094 _ => expr,
1095 }
1096 }
1097 _ => expr,
1098 }
1099}
1100
1101fn extract_lit_str(expr: &Expr) -> Option<String> {
1103 match strip_expr(expr) {
1104 Expr::Lit(ExprLit {
1105 lit: Lit::Str(lit), ..
1106 }) => Some(lit.value()),
1107 _ => None,
1108 }
1109}
1110
1111fn result_flag_ident(name: &str) -> syn::Ident {
1117 format_ident!("__sql_forge_result_flag_{}", name)
1118}
1119
1120fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
1124 fn walk(stream: TokenStream2) -> TokenStream2 {
1125 let mut out = TokenStream2::new();
1126 let iter = stream.into_iter().peekable();
1127
1128 for token in iter {
1129 match token {
1130 TokenTree::Group(group) => {
1131 if group.delimiter() == Delimiter::Brace {
1132 let mut inner = group.stream().into_iter();
1133 let first = inner.next();
1134 let second = inner.next();
1135 let third = inner.next();
1136
1137 if let (
1138 Some(TokenTree::Punct(p)),
1139 Some(TokenTree::Ident(name_ident)),
1140 None,
1141 ) = (first, second, third)
1142 {
1143 if p.as_char() == '>' {
1144 let ident = result_flag_ident(&name_ident.to_string());
1145 out.extend(std::iter::once(TokenTree::Ident(ident)));
1146 continue;
1147 }
1148 }
1149 }
1150
1151 let new_inner = walk(group.stream());
1152 let mut new_group = Group::new(group.delimiter(), new_inner);
1153 new_group.set_span(group.span());
1154 out.extend(std::iter::once(TokenTree::Group(new_group)));
1155 }
1156 other => out.extend(std::iter::once(other)),
1157 }
1158 }
1159
1160 out
1161 }
1162
1163 walk(input)
1164}
1165
1166fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1168 keys.iter()
1169 .map(|key| {
1170 let ident = result_flag_ident(key);
1171 let enabled = Some(key.as_str()) == active_key;
1172 quote! { let #ident: bool = #enabled; }
1173 })
1174 .collect()
1175}
1176
1177fn transpose_section_case_matrix(
1179 case_matrix: Vec<Vec<SectionFragment>>,
1180 width: usize,
1181) -> Result<Vec<Vec<SectionFragment>>, String> {
1182 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1183
1184 for row in case_matrix {
1185 if row.len() != width {
1186 return Err(
1187 "sql_forge!: grouped sections must return one item per section".to_string(),
1188 );
1189 }
1190 for (section_idx, fragment) in row.into_iter().enumerate() {
1191 per_section[section_idx].push(fragment);
1192 }
1193 }
1194
1195 Ok(per_section)
1196}
1197
1198fn collect_section_case_matrix(
1200 value: SectionValue,
1201 width: usize,
1202 active_key: Option<&str>,
1203) -> Result<Vec<Vec<SectionFragment>>, String> {
1204 match value {
1205 SectionValue::Single(fragment) => {
1206 if width != 1 {
1207 return Err(
1208 "sql_forge!: grouped sections must return one item per section".to_string(),
1209 );
1210 }
1211 Ok(vec![vec![fragment]])
1212 }
1213 SectionValue::Grouped(values) => {
1214 if values.len() != width {
1215 return Err(
1216 "sql_forge!: grouped sections must return one item per section".to_string(),
1217 );
1218 }
1219
1220 let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1221 let mut nmax = 1usize;
1222
1223 for value in values {
1224 let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1225 let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1226 for mut row in item_matrix {
1227 let fragment = row.pop().ok_or_else(|| {
1228 "sql_forge!: grouped sections must return one item per section".to_string()
1229 })?;
1230 if !row.is_empty() {
1231 return Err(
1232 "sql_forge!: grouped sections must return one item per section"
1233 .to_string(),
1234 );
1235 }
1236 item_variants.push(fragment);
1237 }
1238 if item_variants.is_empty() {
1239 return Err("sql_forge!: section match must have at least one arm".to_string());
1240 }
1241 nmax = nmax.max(item_variants.len());
1242 variants_by_section.push(item_variants);
1243 }
1244
1245 let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1246 for case_idx in 0..nmax {
1247 let mut row = Vec::<SectionFragment>::with_capacity(width);
1248 for variants in &variants_by_section {
1249 row.push(variants[case_idx % variants.len()].clone());
1250 }
1251 case_matrix.push(row);
1252 }
1253
1254 Ok(case_matrix)
1255 }
1256 SectionValue::Match { expr, arms } => {
1257 let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1258
1259 if let Some(key) = expr_result_flag_key(&expr) {
1260 let target = active_key == Some(key.as_str());
1261 for arm in arms {
1262 if arm.guard.is_none() {
1263 if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1264 continue;
1265 }
1266 }
1267 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1268 wrap_section_case_matrix_for_match_arm(
1269 &mut arm_cases,
1270 &expr,
1271 &arm.pat,
1272 arm.guard.as_ref(),
1273 );
1274 case_matrix.extend(arm_cases);
1275 }
1276 } else {
1277 for arm in arms {
1278 let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1279 wrap_section_case_matrix_for_match_arm(
1280 &mut arm_cases,
1281 &expr,
1282 &arm.pat,
1283 arm.guard.as_ref(),
1284 );
1285 case_matrix.extend(arm_cases);
1286 }
1287 }
1288
1289 if case_matrix.is_empty() {
1290 return Err("sql_forge!: section match must have at least one arm".to_string());
1291 }
1292
1293 Ok(case_matrix)
1294 }
1295 }
1296}
1297
/// Wraps `expr` in a `match &(match_expr) { pat [if guard] => ..., _ => unreachable!() }`
/// expression so that `expr` is evaluated inside the scope of the arm's pattern.
///
/// When the pattern is judged to bind values, the arm body is
/// `{ let _ = &binding; ... expr }`; otherwise it is `{ &(expr) }`.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    // Shallow check: only an identifier pattern at the top level, or directly
    // inside one level of the listed pattern kinds, counts as a binding.
    // NOTE(review): more deeply nested bindings (e.g. `Some(Foo { x })`) are
    // not detected here — confirm this is intended.
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        // One `let _ = &ident;` statement per identifier reported by `pat_var_idents`.
        // NOTE(review): presumably this marks the bindings as used — confirm.
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1367
1368fn wrap_params_source_for_match_arm(
1370 params: &mut ParamsSource,
1371 match_expr: &Expr,
1372 pat: &Pat,
1373 guard: Option<&Expr>,
1374) {
1375 match params {
1376 ParamsSource::None => {}
1377 ParamsSource::Map(entries) => {
1378 for entry in entries {
1379 entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1380 }
1381 }
1382 ParamsSource::Struct(expr) => {
1383 **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1384 }
1385 }
1386}
1387
1388fn wrap_section_case_matrix_for_match_arm(
1390 case_matrix: &mut [Vec<SectionFragment>],
1391 match_expr: &Expr,
1392 pat: &Pat,
1393 guard: Option<&Expr>,
1394) {
1395 for row in case_matrix {
1396 for fragment in row {
1397 wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1398 }
1399 }
1400}
1401
1402fn collect_section_variants(
1412 value: SectionValue,
1413 width: usize,
1414) -> Result<Vec<Vec<SectionFragment>>, String> {
1415 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1416}
1417
1418fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1420 match strip_expr(expr) {
1421 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1422 let name = path.path.segments[0].ident.to_string();
1423 name.strip_prefix("__sql_forge_result_flag_")
1424 .map(|v| v.to_string())
1425 }
1426 _ => None,
1427 }
1428}
1429
1430fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1432 match pat {
1433 Pat::Lit(expr_lit) => match &expr_lit.lit {
1434 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1435 _ => None,
1436 },
1437 Pat::Wild(_) => Some(true),
1438 _ => None,
1439 }
1440}
1441
1442fn collect_section_variants_for_result(
1447 value: SectionValue,
1448 width: usize,
1449 active_key: Option<&str>,
1450) -> Result<Vec<Vec<SectionFragment>>, String> {
1451 transpose_section_case_matrix(
1452 collect_section_case_matrix(value, width, active_key)?,
1453 width,
1454 )
1455}
1456
1457fn build_param_bindings(
1465 params: &ParamsSource,
1466 used_param_names: &[String],
1467 prefix: &str,
1468 for_validator: bool,
1469 enforce_usage_check: bool,
1470) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1471 let mut declared_params = HashMap::<String, syn::Ident>::new();
1472 let mut bindings = Vec::<TokenStream2>::new();
1473
1474 match params {
1475 ParamsSource::None => {}
1476 ParamsSource::Map(entries) => {
1477 for entry in entries {
1478 let key = entry.name.to_string();
1479 if declared_params.contains_key(&key) {
1480 return Err(syn::Error::new(
1481 entry.name.span(),
1482 "sql_forge!: duplicated parameter mapping",
1483 )
1484 .to_compile_error()
1485 .into());
1486 }
1487 if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1488 return Err(syn::Error::new(
1489 entry.name.span(),
1490 format!(
1491 "sql_forge!: parameter :{} is unused in the SQL template",
1492 key,
1493 ),
1494 )
1495 .to_compile_error()
1496 .into());
1497 }
1498 let local_ident = format_ident!("__sql_forge_{}_{}", prefix, key);
1499 let expr = &entry.expr;
1500 if for_validator {
1501 bindings.push(quote! {
1502 let #local_ident = &(#expr);
1503 });
1504 } else {
1505 bindings.push(quote! {
1506 let #local_ident = #expr;
1507 });
1508 }
1509 declared_params.insert(key, local_ident);
1510 }
1511 }
1512 ParamsSource::Struct(expr) => {
1513 let source_ident = format_ident!("__sql_forge_source_{}", prefix);
1514 bindings.push(quote! {
1515 let #source_ident = &(#expr);
1516 });
1517 for name in used_param_names {
1518 let local_ident = format_ident!("__sql_forge_{}_{}", prefix, name);
1519 let field_ident = format_ident!("{}", name);
1520 if for_validator {
1521 bindings.push(quote! {
1522 let #local_ident = &#source_ident.#field_ident;
1523 });
1524 } else {
1525 bindings.push(quote! {
1526 let #local_ident = #source_ident.#field_ident;
1527 });
1528 }
1529 declared_params.insert(name.to_string(), local_ident);
1530 }
1531 }
1532 }
1533
1534 Ok((declared_params, bindings))
1535}
1536
/// Read-only settings shared by `render_validator_args` calls.
struct ValidatorRenderContext<'a> {
    /// Parameter name -> local identifier that holds the bound value.
    params: &'a HashMap<String, syn::Ident>,
    /// Forwarded to `render_validator_sql`; when set, validator argument values are also
    /// wrapped in `sql_forge::sql_forge_validator_value(...)`.
    use_dollar_params: bool,
    /// Span used for the "parameter has no mapping" diagnostic.
    sql_span: Span,
    /// Number of validator arguments generated for each list parameter occurrence.
    list_count: usize,
}
1543
1544fn render_validator_args(
1549 sql: &str,
1550 param_offset: &mut usize,
1551 arg_index: &mut usize,
1552 context: &ValidatorRenderContext<'_>,
1553) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1554 let parts = parse_text_parts(sql);
1555 let (rendered_sql, occurrences, _batch_args) = render_validator_sql(
1556 &parts,
1557 context.use_dollar_params,
1558 param_offset,
1559 context.list_count,
1560 None,
1561 )?;
1562
1563 let mut setup = Vec::<TokenStream2>::new();
1564 let mut args = Vec::<TokenStream2>::new();
1565
1566 for (name, is_list) in occurrences {
1567 let Some(local_ident) = context.params.get(&name) else {
1568 return Err(syn::Error::new(
1569 context.sql_span,
1570 format!("sql_forge!: parameter :{} has no mapping", name),
1571 )
1572 .to_compile_error()
1573 .into());
1574 };
1575
1576 if is_list {
1577 for _ in 0..context.list_count {
1578 let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1579 *arg_index += 1;
1580 if context.use_dollar_params {
1581 setup.push(quote! {
1582 let #value_ident = sql_forge::sql_forge_validator_value(
1583 (#local_ident)
1584 .as_slice()
1585 .first()
1586 .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1587 );
1588 });
1589 } else {
1590 setup.push(quote! {
1591 let #value_ident = (#local_ident)
1592 .as_slice()
1593 .first()
1594 .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1595 });
1596 }
1597 args.push(quote! { #value_ident });
1598 }
1599 } else {
1600 let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1601 *arg_index += 1;
1602 if context.use_dollar_params {
1603 setup.push(quote! {
1604 let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1605 });
1606 } else {
1607 setup.push(quote! {
1608 let #value_ident = #local_ident;
1609 });
1610 }
1611 args.push(quote! { #value_ident });
1612 }
1613 }
1614
1615 Ok((rendered_sql, setup, args))
1616}
1617
1618fn render_runtime_fragment(
1625 fragment: &SectionFragment,
1626 local_params: &HashMap<String, syn::Ident>,
1627) -> Result<TokenStream2, TokenStream> {
1628 let mut steps = Vec::<TokenStream2>::new();
1629
1630 for part in parse_text_parts(&fragment.sql) {
1631 match part {
1632 TextPart::Lit(lit) => {
1633 let lit_str = LitStr::new(&lit, fragment.span);
1634 steps.push(quote! { __builder.push(#lit_str); });
1635 }
1636 TextPart::Param { name, is_list } => {
1637 let Some(local_ident) = local_params.get(&name) else {
1638 return Err(syn::Error::new(
1639 fragment.span,
1640 format!("sql_forge!: parameter :{} has no mapping", name),
1641 )
1642 .to_compile_error()
1643 .into());
1644 };
1645
1646 if is_list {
1647 steps.push(quote! {
1648 let __sql_forge_values = #local_ident;
1649 let mut __separated = __builder.separated(", ");
1650 for __value in __sql_forge_values {
1651 __separated.push_bind(__value);
1652 }
1653 });
1654 } else {
1655 steps.push(quote! {
1656 __builder.push_bind(#local_ident);
1657 });
1658 }
1659 }
1660 }
1661 }
1662
1663 Ok(quote! { #( #steps )* })
1664}
1665
1666fn is_pat_binding(ident: &Ident) -> bool {
1668 let name = ident.to_string();
1669 !name.is_empty()
1670 && name
1671 .chars()
1672 .next()
1673 .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1674}
1675
1676fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1678 let mut names = Vec::new();
1679 fn walk(p: &Pat, names: &mut Vec<Ident>) {
1680 match p {
1681 Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1682 Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1683 Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1684 Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1685 Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1686 Pat::Paren(pp) => walk(&pp.pat, names),
1687 Pat::Reference(pr) => walk(&pr.pat, names),
1688 Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1689 Pat::Type(pt) => walk(&pt.pat, names),
1690 _ => {}
1691 }
1692 }
1693 walk(pat, &mut names);
1694 names
1695}
1696
1697fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1699 match value {
1700 SectionValue::Single(f) => {
1701 if collect_used_param_names_in_sql(&f.sql)
1702 .iter()
1703 .any(|n| n == name)
1704 {
1705 return true;
1706 }
1707 if let ParamsSource::Map(entries) = &f.params {
1708 for e in entries {
1709 let expr = &e.expr;
1710 let expr_str = quote! { #expr }.to_string();
1711 if expr_str.trim() == name {
1712 return true;
1713 }
1714 }
1715 }
1716 false
1717 }
1718 SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1719 SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1720 let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1721 .into_iter()
1722 .map(|i| i.to_string())
1723 .collect();
1724 if pat_vars.contains(name) {
1725 false
1726 } else {
1727 section_value_refers_to(&arm.value, name)
1728 }
1729 }),
1730 }
1731}
1732
1733fn build_section_runtime_action(
1735 value: &SectionValue,
1736 section_idx: usize,
1737 prefix: &str,
1738) -> Result<TokenStream2, TokenStream> {
1739 match value {
1740 SectionValue::Single(fragment) => {
1741 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1742 let (local_params, bindings) =
1743 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1744 let body = render_runtime_fragment(fragment, &local_params)?;
1745 Ok(quote! {{ #( #bindings )* #body }})
1746 }
1747 SectionValue::Grouped(fragments) => build_section_runtime_action(
1748 &fragments[section_idx],
1749 0,
1750 &format!("{}_grouped_{}", prefix, section_idx),
1751 ),
1752 SectionValue::Match { expr, arms } => {
1753 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1754 .iter()
1755 .enumerate()
1756 .map(|(arm_idx, arm)| {
1757 let pat = &arm.pat;
1758 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1759 let body = build_section_runtime_action(
1760 &arm.value,
1761 section_idx,
1762 &format!("{}_{}", prefix, arm_idx),
1763 )?;
1764 let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1765 .into_iter()
1766 .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1767 .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1768 .collect();
1769 Ok::<TokenStream2, TokenStream>(quote! {
1770 #pat #guard_tokens => {
1771 #( #noop_refs )*
1772 #body
1773 }
1774 })
1775 })
1776 .collect();
1777 let arm_tokens = arm_tokens?;
1778 Ok(quote! {
1779 match #expr {
1780 #( #arm_tokens ),*
1781 }
1782 })
1783 }
1784 }
1785}
1786
1787fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1789 let mut names = Vec::new();
1790 let mut seen = HashSet::<String>::new();
1791
1792 for segment in segments {
1793 match segment {
1794 Segment::Text(text) => {
1795 for name in collect_used_param_names_in_sql(text) {
1796 if seen.insert(name.clone()) {
1797 names.push(name);
1798 }
1799 }
1800 }
1801 Segment::Batch { parts } => {
1802 for part in parts {
1803 if let TextPart::Param { name, .. } = part {
1804 if seen.insert(name.clone()) {
1805 names.push(name.clone());
1806 }
1807 }
1808 }
1809 }
1810 _ => {}
1811 }
1812 }
1813
1814 names
1815}
1816
1817fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1819 let mut names = Vec::new();
1820 let mut seen = HashSet::<String>::new();
1821 for part in parse_text_parts(sql) {
1822 if let TextPart::Param { name, .. } = part {
1823 if seen.insert(name.to_string()) {
1824 names.push(name);
1825 }
1826 }
1827 }
1828 names
1829}
1830
1831#[proc_macro]
2057#[allow(clippy::too_many_lines)]
2058pub fn sql_forge(input: TokenStream) -> TokenStream {
2059 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
2061 let SqlForgeInput {
2062 db,
2063 result,
2064 force_scalar,
2065 sql,
2066 params,
2067 sections,
2068 batch,
2069 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
2070 Ok(v) => v,
2071 Err(err) => return err.to_compile_error().into(),
2072 };
2073
2074 let db = match db {
2076 Some(db) => db,
2077 None => match resolve_db_from_env() {
2078 Ok(db) => db,
2079 Err(msg) => {
2080 return syn::Error::new(Span::call_site(), msg)
2081 .to_compile_error()
2082 .into();
2083 }
2084 },
2085 };
2086
2087 let use_dollar_params = uses_dollar_params(&db);
2088 let list_count: usize = 3;
2089
2090 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
2094 ResultSpec::None => {
2095 vec![(None, None, None)]
2096 }
2097 ResultSpec::Single(ref model) => {
2098 let model_ty = (**model).clone();
2099 let scalar = if force_scalar {
2100 Some(model_ty.clone())
2101 } else {
2102 scalar_output_type(model.as_ref()).cloned()
2103 };
2104 vec![(None, Some(model_ty), scalar)]
2105 }
2106 ResultSpec::Group(ref cases) => {
2107 if force_scalar {
2108 return syn::Error::new(
2109 Span::call_site(),
2110 "sql_forge!: scalar mode is not supported for grouped result maps",
2111 )
2112 .to_compile_error()
2113 .into();
2114 }
2115
2116 let mut out = Vec::new();
2117 let mut seen = HashSet::new();
2118 for case in cases {
2119 let key = case.name.to_string();
2120 if !seen.insert(key.clone()) {
2121 return syn::Error::new(
2122 case.name.span(),
2123 "sql_forge!: duplicated key in result map",
2124 )
2125 .to_compile_error()
2126 .into();
2127 }
2128
2129 let model = case.model.clone();
2130 let scalar = if case.force_scalar {
2131 Some(model.clone())
2132 } else {
2133 scalar_output_type(&case.model).cloned()
2134 };
2135 out.push((Some(key), Some(model), scalar));
2136 }
2137 out
2138 }
2139 };
2140 let group_result_keys: Vec<String> = result_cases
2141 .iter()
2142 .filter_map(|(key, _, _)| key.as_ref().cloned())
2143 .collect();
2144 let is_grouped_result = !group_result_keys.is_empty();
2145 let sql_span = sql.span();
2146
2147 let segments = match sql.into_segments() {
2149 Ok(segments) => segments,
2150 Err(msg) => {
2151 return syn::Error::new(sql_span, msg).to_compile_error().into();
2152 }
2153 };
2154
2155 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2156 match (&batch, has_batch_segment) {
2157 (None, true) => {
2158 return syn::Error::new(
2159 sql_span,
2160 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2161 was provided"
2162 )
2163 .to_compile_error()
2164 .into();
2165 }
2166 (Some(_), false) => {
2167 return syn::Error::new(
2168 sql_span,
2169 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2170 batch section",
2171 )
2172 .to_compile_error()
2173 .into();
2174 }
2175 _ => {}
2176 }
2177
2178 let used_param_names = collect_used_param_names(&segments);
2179
2180 let text_param_names: std::collections::HashSet<String> = segments
2188 .iter()
2189 .filter_map(|s| {
2190 if let Segment::Text(text) = s {
2191 Some(collect_used_param_names_in_sql(text).into_iter())
2192 } else {
2193 None
2194 }
2195 })
2196 .flatten()
2197 .collect();
2198 let top_level_used_names: Vec<String> = used_param_names
2199 .iter()
2200 .filter(|n| text_param_names.contains(*n))
2201 .cloned()
2202 .collect();
2203
2204 let (declared_params, validator_param_bindings) =
2206 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
2207 Ok(v) => v,
2208 Err(err) => return err,
2209 };
2210
2211 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2212
2213 for assign in §ions {
2215 let SectionAssign { names, value } = assign;
2216
2217 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2219 for (section_idx, name_ident) in names.iter().enumerate() {
2220 let name = name_ident.to_string();
2221 if runtime_section_actions.contains_key(&name) {
2222 return syn::Error::new(
2223 name_ident.span(),
2224 "sql_forge!: duplicated section mapping",
2225 )
2226 .to_compile_error()
2227 .into();
2228 }
2229 let action = match build_section_runtime_action(
2230 value,
2231 section_idx,
2232 &format!("section_{}", name),
2233 ) {
2234 Ok(action) => action,
2235 Err(err) => return err,
2236 };
2237 named_actions.push((name, action));
2238 }
2239
2240 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2242 return syn::Error::new(names[0].span(), msg)
2243 .to_compile_error()
2244 .into();
2245 }
2246
2247 for (name, action) in named_actions {
2248 runtime_section_actions.insert(name, action);
2249 }
2250 }
2251
2252 let sql_section_names: std::collections::HashSet<&str> = segments
2253 .iter()
2254 .filter_map(|seg| {
2255 if let Segment::Section { name } = seg {
2256 Some(name.as_str())
2257 } else {
2258 None
2259 }
2260 })
2261 .collect();
2262 for name in runtime_section_actions.keys() {
2263 if !sql_section_names.contains(name.as_str()) {
2264 return syn::Error::new(
2265 sql_span,
2266 format!(
2267 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2268 name, name,
2269 ),
2270 )
2271 .to_compile_error()
2272 .into();
2273 }
2274 }
2275
2276 let mut generated_query_defs = Vec::<TokenStream2>::new();
2278 let mut generated_query_values = Vec::<TokenStream2>::new();
2279 let mut group_field_defs = Vec::<TokenStream2>::new();
2280 let mut group_field_idents = Vec::<syn::Ident>::new();
2281 let mut group_field_tys = Vec::<TokenStream2>::new();
2282 let mut group_trait_impls = Vec::<TokenStream2>::new();
2283
2284 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2285
2286 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2287 let suffix = result_key.as_deref().unwrap_or("single");
2288 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2289 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2290
2291 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2292
2293 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2294 for assign in §ions {
2295 let SectionAssign { names, value } = assign;
2296 let variants_by_section = match collect_section_variants_for_result(
2297 value.clone(),
2298 names.len(),
2299 result_key.as_deref(),
2300 ) {
2301 Ok(v) => v,
2302 Err(msg) => {
2303 return syn::Error::new(names[0].span(), msg)
2304 .to_compile_error()
2305 .into();
2306 }
2307 };
2308
2309 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2310 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2311 }
2312 }
2313
2314 let mut nmax = 1usize;
2315 for segment in &segments {
2316 if let Segment::Section { name } = segment {
2317 if let Some(variants) = section_variants_for_validation.get(name) {
2318 if variants.is_empty() {
2319 return syn::Error::new(
2320 sql_span,
2321 format!("sql_forge!: section {{#{}}} has no possible variants", name),
2322 )
2323 .to_compile_error()
2324 .into();
2325 }
2326 nmax = nmax.max(variants.len());
2327 } else {
2328 return syn::Error::new(
2329 sql_span,
2330 format!("sql_forge!: section {{#{}}} has no mapping", name),
2331 )
2332 .to_compile_error()
2333 .into();
2334 }
2335 }
2336 }
2337
2338 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2339 for case_idx in 0..nmax {
2340 let mut sql_case = String::new();
2341 let mut case_setup = Vec::<TokenStream2>::new();
2342 let mut case_args = Vec::<TokenStream2>::new();
2343 let mut param_offset = 0usize;
2344 let mut arg_index = 0usize;
2345 let root_validator_context = ValidatorRenderContext {
2346 params: &declared_params,
2347 use_dollar_params,
2348 sql_span,
2349 list_count,
2350 };
2351
2352 for segment in &segments {
2353 match segment {
2354 Segment::Text(text) => {
2355 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2356 text,
2357 &mut param_offset,
2358 &mut arg_index,
2359 &root_validator_context,
2360 ) {
2361 Ok(value) => value,
2362 Err(err) => return err,
2363 };
2364 sql_case.push_str(&chunk_sql);
2365 case_setup.extend(chunk_setup);
2366 case_args.extend(chunk_args);
2367 }
2368 Segment::Section { name } => {
2369 let Some(variants) = section_variants_for_validation.get(name) else {
2370 return syn::Error::new(
2371 sql_span,
2372 format!("sql_forge!: section {{#{}}} has no mapping", name),
2373 )
2374 .to_compile_error()
2375 .into();
2376 };
2377
2378 let fragment = &variants[case_idx % variants.len()];
2379 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2380 let (local_params, bindings) = match build_param_bindings(
2381 &fragment.params,
2382 &used_param_names,
2383 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2384 true,
2385 true,
2386 ) {
2387 Ok(value) => value,
2388 Err(err) => return err,
2389 };
2390 let section_validator_context = ValidatorRenderContext {
2391 params: &local_params,
2392 use_dollar_params,
2393 sql_span: fragment.span,
2394 list_count,
2395 };
2396 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2397 &fragment.sql,
2398 &mut param_offset,
2399 &mut arg_index,
2400 §ion_validator_context,
2401 ) {
2402 Ok(value) => value,
2403 Err(err) => return err,
2404 };
2405 sql_case.push_str(&chunk_sql);
2406 case_setup.extend(bindings);
2407 case_setup.extend(chunk_setup);
2408 case_args.extend(chunk_args);
2409 }
2410 Segment::Batch { parts } => {
2411 let batch_ts = batch.as_ref().map(|e| quote! { #e });
2412 let mut first = true;
2413 for _ in 0..list_count {
2414 let sep = if first { "" } else { ", " };
2415 first = false;
2416 sql_case.push_str(sep);
2417 let (chunk_sql, _occurrences, chunk_args) = match render_validator_sql(
2418 parts,
2419 use_dollar_params,
2420 &mut param_offset,
2421 list_count,
2422 batch_ts.clone(),
2423 ) {
2424 Ok(value) => value,
2425 Err(err) => return err,
2426 };
2427 sql_case.push_str(&chunk_sql);
2428 case_args.extend(chunk_args);
2429 }
2430 }
2431 }
2432 }
2433
2434 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2435 }
2436
2437 let mut validator_invocations = Vec::<TokenStream2>::new();
2438 for (sql_lit, case_setup, args) in &validator_cases {
2439 if model_opt.is_none() {
2440 if args.is_empty() {
2441 validator_invocations.push(quote! {
2442 {
2443 #( #case_setup )*
2444 let _ = sqlx::query_scalar!(
2445 #sql_lit,
2446 );
2447 }
2448 });
2449 } else {
2450 validator_invocations.push(quote! {
2451 {
2452 #( #case_setup )*
2453 let _ = sqlx::query_scalar!(
2454 #sql_lit,
2455 #( #args ),*
2456 );
2457 }
2458 });
2459 }
2460 } else if let Some(scalar_ty) = scalar_model_ty {
2461 if args.is_empty() {
2462 validator_invocations.push(quote! {
2463 {
2464 #( #case_setup )*
2465 let _ = sqlx::query_scalar!(
2466 #sql_lit,
2467 );
2468 }
2469 });
2470 } else {
2471 validator_invocations.push(quote! {
2472 {
2473 #( #case_setup )*
2474 let _ = sqlx::query_scalar!(
2475 #sql_lit,
2476 #( #args ),*
2477 );
2478 }
2479 });
2480 }
2481 let _ = scalar_ty;
2482 } else if args.is_empty() {
2483 validator_invocations.push(quote! {
2484 {
2485 #( #case_setup )*
2486 let _ = sqlx::query_as!(
2487 __SqlForgeModel,
2488 #sql_lit,
2489 );
2490 }
2491 });
2492 } else {
2493 validator_invocations.push(quote! {
2494 {
2495 #( #case_setup )*
2496 let _ = sqlx::query_as!(
2497 __SqlForgeModel,
2498 #sql_lit,
2499 #( #args ),*
2500 );
2501 }
2502 });
2503 }
2504 }
2505
2506 let model_alias = if let Some(model) = model_opt {
2507 if scalar_model_ty.is_none() {
2508 quote! { type __SqlForgeModel = #model; }
2509 } else {
2510 quote! {}
2511 }
2512 } else {
2513 quote! {}
2514 };
2515 grouped_validator_invocations.push(quote! {
2516 {
2517 #( #flag_bindings )*
2518 #model_alias
2519 #( #validator_invocations )*
2520 }
2521 });
2522
2523 let (runtime_declared_params, runtime_param_bindings) =
2524 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2525 Ok(v) => v,
2526 Err(err) => return err,
2527 };
2528
2529 let mut runtime_steps = Vec::<TokenStream2>::new();
2530 for (seg_idx, segment) in segments.iter().enumerate() {
2531 match segment {
2532 Segment::Text(text) => {
2533 for part in parse_text_parts(text) {
2534 match part {
2535 TextPart::Lit(lit) => {
2536 let lit = sanitize_runtime_sql_text(&lit);
2537 let lit_str = LitStr::new(&lit, sql_span);
2538 runtime_steps.push(quote! {
2539 __builder.push(#lit_str);
2540 });
2541 }
2542 TextPart::Param { name, is_list } => {
2543 let Some(local_ident) = runtime_declared_params.get(&name) else {
2544 return syn::Error::new(
2545 sql_span,
2546 format!("sql_forge!: parameter :{} has no mapping", name),
2547 )
2548 .to_compile_error()
2549 .into();
2550 };
2551
2552 if is_list {
2553 runtime_steps.push(quote! {
2554 let __sql_forge_values = #local_ident;
2555 let mut __separated = __builder.separated(", ");
2556 for __value in __sql_forge_values {
2557 __separated.push_bind(__value);
2558 }
2559 });
2560 } else {
2561 runtime_steps.push(quote! {
2562 __builder.push_bind(#local_ident);
2563 });
2564 }
2565 }
2566 }
2567 }
2568 }
2569 Segment::Section { name } => {
2570 let Some(section_action) = runtime_section_actions.get(name) else {
2571 let _ = seg_idx;
2572 return syn::Error::new(
2573 sql_span,
2574 format!("sql_forge!: section {{#{}}} has no mapping", name),
2575 )
2576 .to_compile_error()
2577 .into();
2578 };
2579 runtime_steps.push(quote! {
2580 #section_action
2581 });
2582 }
2583 Segment::Batch { parts } => {
2584 if let Some(batch_expr) = &batch {
2585 let mut body = Vec::<TokenStream2>::new();
2586 for part in parts {
2587 match part {
2588 TextPart::Lit(lit) => {
2589 let lit_str = LitStr::new(lit, sql_span);
2590 body.push(quote! {
2591 __builder.push(#lit_str);
2592 });
2593 }
2594 TextPart::Param { name, .. } => {
2595 let field_ident = format_ident!("{}", name);
2596 body.push(quote! {
2597 __builder.push_bind(__item.#field_ident);
2598 });
2599 }
2600 }
2601 }
2602 runtime_steps.push(quote! {
2603 {
2604 let mut __first = true;
2605 for __item in #batch_expr {
2606 if !__first {
2607 __builder.push(", ");
2608 }
2609 __first = false;
2610 #( #body )*
2611 }
2612 }
2613 });
2614 }
2615 }
2616 }
2617 }
2618
2619 let exec_methods = if model_opt.is_none() {
2620 quote! {
2621 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2622 where
2623 E: sqlx::Executor<'e, Database = #db>,
2624 {
2625 self.inner.build().execute(executor).await
2626 }
2627 }
2628 } else if let Some(scalar_ty) = scalar_model_ty {
2629 quote! {
2630 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2631 where
2632 E: sqlx::Executor<'e, Database = #db>,
2633 {
2634 self.inner
2635 .build_query_scalar::<#scalar_ty>()
2636 .fetch_all(executor)
2637 .await
2638 }
2639
2640 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2641 where
2642 E: sqlx::Executor<'e, Database = #db>,
2643 {
2644 self.inner
2645 .build_query_scalar::<#scalar_ty>()
2646 .fetch_one(executor)
2647 .await
2648 }
2649
2650 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2651 where
2652 E: sqlx::Executor<'e, Database = #db>,
2653 {
2654 self.inner
2655 .build_query_scalar::<#scalar_ty>()
2656 .fetch_optional(executor)
2657 .await
2658 }
2659
2660 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2661 where
2662 E: sqlx::Executor<'e, Database = #db>,
2663 {
2664 self.inner.build().execute(executor).await
2665 }
2666 }
2667 } else {
2668 let model = model_opt.as_ref().unwrap();
2669 quote! {
2670 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2671 where
2672 E: sqlx::Executor<'e, Database = #db>,
2673 {
2674 self.inner.build_query_as::<#model>().fetch_all(executor).await
2675 }
2676
2677 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2678 where
2679 E: sqlx::Executor<'e, Database = #db>,
2680 {
2681 self.inner.build_query_as::<#model>().fetch_one(executor).await
2682 }
2683
2684 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2685 where
2686 E: sqlx::Executor<'e, Database = #db>,
2687 {
2688 self.inner
2689 .build_query_as::<#model>()
2690 .fetch_optional(executor)
2691 .await
2692 }
2693
2694 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2695 where
2696 E: sqlx::Executor<'e, Database = #db>,
2697 {
2698 self.inner.build().execute(executor).await
2699 }
2700 }
2701 };
2702
2703 let final_type: TokenStream2 = if let Some(model) = model_opt {
2704 if let Some(scalar_ty) = scalar_model_ty {
2705 quote! { #scalar_ty }
2706 } else {
2707 quote! { #model }
2708 }
2709 } else {
2710 quote! {}
2711 };
2712 let trait_impl = if model_opt.is_none() {
2713 quote! {
2714 impl sql_forge::SqlForgeQueryExecute
2715 for #query_ident
2716 {
2717 type Db = #db;
2718
2719 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2720 where
2721 Self: Sized + 'e,
2722 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2723 #db: 'e,
2724 {
2725 #query_ident::execute(self, executor)
2726 }
2727 }
2728 }
2729 } else {
2730 quote! {
2731 impl sql_forge::SqlForgeQuery<#final_type>
2732 for #query_ident
2733 {
2734 type Db = #db;
2735
2736 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2737 where
2738 Self: Sized + 'e,
2739 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2740 #db: 'e,
2741 {
2742 #query_ident::fetch_all(self, executor)
2743 }
2744
2745 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2746 where
2747 Self: Sized + 'e,
2748 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2749 #db: 'e,
2750 {
2751 #query_ident::fetch_one(self, executor)
2752 }
2753
2754 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2755 where
2756 Self: Sized + 'e,
2757 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2758 #db: 'e,
2759 {
2760 #query_ident::fetch_optional(self, executor)
2761 }
2762
2763 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2764 where
2765 Self: Sized + 'e,
2766 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2767 #db: 'e,
2768 {
2769 #query_ident::execute(self, executor)
2770 }
2771 }
2772 }
2773 };
2774
2775 generated_query_defs.push(quote! {
2776 struct #query_ident {
2777 inner: sqlx::QueryBuilder<#db>,
2778 }
2779
2780 impl #query_ident {
2781 #exec_methods
2782 }
2783
2784 #trait_impl
2785 });
2786
2787 generated_query_values.push(quote! {
2788 #( #runtime_param_bindings )*
2789 #( #flag_bindings )*
2790 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2791 #( #runtime_steps )*
2792 let #query_value_ident = #query_ident { inner: __builder };
2793 });
2794
2795 if let Some(key) = result_key {
2796 let method_ident = format_ident!("{}", key);
2797 group_field_defs.push(quote! {
2798 #method_ident: #query_ident
2799 });
2800 group_field_tys.push(quote! { #query_ident });
2801
2802 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2803 group_trait_impls.push(quote! {
2804 struct #key_ty_ident;
2805
2806 impl sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup {
2807 type Query = #query_ident;
2808
2809 fn get(self, _: #key_ty_ident) -> Self::Query {
2810 self.#method_ident
2811 }
2812 }
2813 });
2814 group_field_idents.push(method_ident);
2815 }
2816 }
2817
2818 let validator_tokens = quote! {
2820 let _sql_forge_validator = || {
2821 #( #validator_param_bindings )*
2822 #( #grouped_validator_invocations )*
2823 };
2824 };
2825
2826 if !is_grouped_result {
2827 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2828 return quote! {
2829 {
2830 #validator_tokens
2831 #( #generated_query_defs )*
2832 #( #generated_query_values )*
2833 #single_query_value_ident
2834 }
2835 }
2836 .into();
2837 }
2838
2839 let group_field_inits: Vec<TokenStream2> = result_cases
2840 .iter()
2841 .filter_map(|(key, _, _)| key.as_ref())
2842 .map(|key| {
2843 let method_ident = format_ident!("{}", key);
2844 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2845 quote! { #method_ident: #query_value_ident }
2846 })
2847 .collect();
2848
2849 quote! {
2850 {
2851 #validator_tokens
2852
2853 #( #generated_query_defs )*
2854 #( #generated_query_values )*
2855
2856 struct __SqlForgeQueryGroup {
2857 #( #group_field_defs, )*
2858 }
2859
2860 impl __SqlForgeQueryGroup {
2861 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2862 ( #( self.#group_field_idents ),* )
2863 }
2864 }
2865
2866 impl sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup {
2867 type Db = #db;
2868 }
2869
2870 #( #group_trait_impls )*
2871
2872 __SqlForgeQueryGroup {
2873 #( #group_field_inits, )*
2874 }
2875 }
2876 }
2877 .into()
2878}
2879
2880#[proc_macro]
2894pub fn db_type(input: TokenStream) -> TokenStream {
2895 if !input.is_empty() {
2896 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2897 .to_compile_error()
2898 .into();
2899 }
2900
2901 match resolve_db_from_env() {
2902 Ok(db) => quote! { #db }.into(),
2903 Err(msg) => syn::Error::new(Span::call_site(), msg)
2904 .to_compile_error()
2905 .into(),
2906 }
2907}
2908
2909#[proc_macro_attribute]
2924pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2925 let input: ItemStruct = match syn::parse(item) {
2926 Ok(v) => v,
2927 Err(err) => return err.to_compile_error().into(),
2928 };
2929
2930 let struct_name = &input.ident;
2931 let inner_type = match &input.fields {
2932 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2933 _ => {
2934 return syn::Error::new(
2935 input.span(),
2936 "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2937 )
2938 .to_compile_error()
2939 .into();
2940 }
2941 };
2942
2943 let attrs = input.attrs;
2944 let generics = &input.generics;
2945 let vis = &input.vis;
2946 let struct_token = input.struct_token;
2947 let semi_token = input.semi_token;
2948 let fields = &input.fields;
2949
2950 let expanded = quote! {
2951 #( #attrs )*
2952 #[derive(sqlx::Type)]
2953 #[sqlx(transparent)]
2954 #vis #struct_token #struct_name #generics #fields #semi_token
2955
2956 impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2957 fn sql_forge_validator_value(&self) -> #inner_type {
2958 self.0.clone()
2959 }
2960 }
2961 };
2962
2963 expanded.into()
2964}