1use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16 Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17 Stmt, Token, Type,
18};
19
/// Custom keywords recognized by the `sql_forge!` argument parser.
mod kw {
    // `scalar` may prefix a result model type (see `ResultAssign::parse` and the
    // `kw::scalar` branches of `SqlForgeInput::parse`).
    syn::custom_keyword!(scalar);
}
27
28#[derive(Clone)]
30struct ParamAssign {
31 name: Ident,
32 expr: Expr,
33}
34
35impl Parse for ParamAssign {
36 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37 input.parse::<Token![:]>()?;
38 let name: Ident = input.parse()?;
39 input.parse::<Token![=]>()?;
40 let expr: Expr = input.parse()?;
41 Ok(Self { name, expr })
42 }
43}
44
/// One SQL fragment that can fill a `{#section}` placeholder, together with any
/// parameters that are local to that fragment.
#[derive(Clone)]
struct SectionFragment {
    /// SQL text of the fragment (the string literal's value).
    sql: String,
    /// Span of the string-literal expression the text came from.
    span: Span,
    /// Section-local parameter source (`ParamsSource::None` for a bare literal).
    params: ParamsSource,
}

/// One `pat [if guard] => value` arm of a `match` section value.
#[derive(Clone)]
struct SectionMatchArm {
    pat: Pat,
    guard: Option<Expr>,
    value: SectionValue,
}

/// Right-hand side of a `#name = ...` / `#(a, b) = ...` section assignment.
#[derive(Clone)]
enum SectionValue {
    /// A single fragment: `"sql"` or `("sql", params)`.
    Single(SectionFragment),
    /// One value per key of a grouped assignment `#(a, b) = (..., ...)`.
    Grouped(Vec<SectionValue>),
    /// `match expr { pat => value, ... }`.
    Match {
        /// The scrutinee expression.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
75#[derive(Clone)]
77struct SectionAssign {
78 names: Vec<Ident>,
79 value: SectionValue,
80}
81
82impl Parse for SectionAssign {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 input.parse::<Token![#]>()?;
85
86 let names = if input.peek(syn::token::Paren) {
88 let content;
89 syn::parenthesized!(content in input);
90 let mut out = Vec::new();
91 while !content.is_empty() {
92 out.push(content.parse::<Ident>()?);
93 if content.is_empty() {
94 break;
95 }
96 content.parse::<Token![,]>()?;
97 }
98 if out.is_empty() {
99 return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100 }
101 out
102 } else {
103 vec![input.parse::<Ident>()?]
104 };
105
106 input.parse::<Token![=]>()?;
107 let value = parse_section_value(input, names.len())?;
108 Ok(Self { names, value })
109 }
110}
111
/// Fully parsed arguments of one `sql_forge!` invocation.
struct SqlForgeInput {
    /// Database type passed explicitly as the first argument; `None` when omitted
    /// (presumably resolved later, e.g. via `resolve_db_from_env` — verify at call site).
    db: Option<Type>,
    /// How result rows are mapped.
    result: ResultSpec,
    /// `true` when the single result type was prefixed with the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template argument.
    sql: SqlTemplate,
    /// Top-level parameter source.
    params: ParamsSource,
    /// Entries of the `(#name = ..., ...)` section map, if one was given.
    sections: Vec<SectionAssign>,
    /// The expression following `..` (batch source), if one was given.
    batch: Option<Expr>,
}

/// One `>name = [scalar] Model` entry of a result map.
#[derive(Clone)]
struct ResultAssign {
    name: Ident,
    model: Type,
    /// `true` when `model` was prefixed with the `scalar` keyword.
    force_scalar: bool,
}

/// Result-mapping form chosen by the invocation.
#[derive(Clone)]
enum ResultSpec {
    /// No result type was given.
    None,
    /// A single model type.
    Single(Box<Type>),
    /// A named result map `(>name = Model, ...)`.
    Group(Vec<ResultAssign>),
}

/// Where named SQL parameters get their values from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameters were supplied.
    None,
    /// An explicit `(:name = expr, ...)` map.
    Map(Vec<ParamAssign>),
    /// A single expression; its fields are presumably read by parameter name
    /// (see `build_param_bindings`).
    Struct(Box<Expr>),
}
149
150enum SqlTemplate {
152 Literal(LitStr),
153}
154
155impl SqlTemplate {
156 fn span(&self) -> Span {
157 match self {
158 Self::Literal(lit) => lit.span(),
159 }
160 }
161
162 fn into_segments(self) -> Result<Vec<Segment>, String> {
164 match self {
165 Self::Literal(lit) => parse_literal_segments(&lit.value()),
166 }
167 }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171 if input.peek(LitStr) {
172 Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173 } else {
174 Err(input.error("sql_forge!: SQL template must be a string literal"))
175 }
176}
177
/// One piece of a parsed SQL template (produced by `parse_literal_segments`).
#[derive(Clone)]
enum Segment {
    /// Literal SQL text; adjacent text is merged by `push_text_segment`.
    Text(String),
    /// A `{#name}` section placeholder.
    Section { name: String },
    /// A `{( ... )}` batch section, already split into text and parameter parts.
    Batch { parts: Vec<TextPart> },
}

/// One piece of SQL text after named-parameter extraction (see `parse_text_parts`).
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:name` parameter; `is_list` is set for the `:name[]` form.
    Param { name: String, is_list: bool },
}

/// Kind of a parenthesized map argument, decided by its first token.
enum MapKind {
    /// `(>name = Model, ...)` result map.
    Results,
    /// `(:name = expr, ...)` parameter map.
    Params,
    /// `(#name = value, ...)` section map.
    Sections,
}
208
209fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
215 let fork = input.fork();
216 let content;
217 syn::parenthesized!(content in fork);
218
219 if content.is_empty() {
220 return Err(input.error("sql_forge!: map argument cannot be empty"));
221 }
222
223 if content.peek(Token![>]) {
224 Ok(Some(MapKind::Results))
225 } else if content.peek(Token![:]) {
226 Ok(Some(MapKind::Params))
227 } else if content.peek(Token![#]) {
228 Ok(Some(MapKind::Sections))
229 } else {
230 Ok(None)
231 }
232}
233
234impl Parse for ResultAssign {
235 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
236 input.parse::<Token![>]>()?;
237 let name: Ident = input.parse()?;
238 input.parse::<Token![=]>()?;
239 let (force_scalar, model) = if input.peek(kw::scalar) {
240 input.parse::<kw::scalar>()?;
241 (true, input.parse::<Type>()?)
242 } else {
243 (false, input.parse::<Type>()?)
244 };
245 Ok(Self {
246 name,
247 model,
248 force_scalar,
249 })
250 }
251}
252
253fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
255 let content;
256 syn::parenthesized!(content in input);
257
258 let mut results = Vec::new();
259 while !content.is_empty() {
260 results.push(content.parse::<ResultAssign>()?);
261 if content.is_empty() {
262 break;
263 }
264 content.parse::<Token![,]>()?;
265 }
266
267 if results.is_empty() {
268 return Err(input.error("sql_forge!: result map cannot be empty"));
269 }
270
271 Ok(results)
272}
273
274fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
276 let content;
277 syn::parenthesized!(content in input);
278
279 let mut params = Vec::new();
280 while !content.is_empty() {
281 params.push(content.parse::<ParamAssign>()?);
282 if content.is_empty() {
283 break;
284 }
285 content.parse::<Token![,]>()?;
286 }
287
288 Ok(params)
289}
290
291fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
293 let content;
294 syn::parenthesized!(content in input);
295
296 let mut sections = Vec::new();
297 while !content.is_empty() {
298 sections.push(content.parse::<SectionAssign>()?);
299 if content.is_empty() {
300 break;
301 }
302 content.parse::<Token![,]>()?;
303 }
304
305 Ok(sections)
306}
307
308fn parse_params_source_expr(input: ParseStream<'_>) -> syn::Result<ParamsSource> {
310 if input.peek(syn::token::Paren) {
311 match detect_parenthesized_map_kind(input)? {
312 Some(MapKind::Results) => Err(input
313 .error("sql_forge!: result maps are only allowed as the macro result argument")),
314 Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
315 Some(MapKind::Sections) => Err(input.error(
316 "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317 )),
318 None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319 }
320 } else {
321 Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322 }
323}
324
/// Parses one section fragment: a bare string literal `"sql"`, or a tuple
/// `("sql", params)` / `("sql", params,)` carrying section-local parameters.
///
/// # Errors
/// Fails when the value is neither form, or when the tuple's parameter source is invalid.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Check the tuple shape on a fork first. If the shape does not match, nothing is
        // consumed and the literal path below handles e.g. `("sql")`.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Errors inside the parameter source are reported immediately.
                let _ = parse_params_source_expr(&content)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // The shape matched: parse it again, this time consuming `input`.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Fallback: the whole value must be a (possibly parenthesized/block-wrapped) string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
373
374fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
376 if input.peek(Token![match]) {
377 input.parse::<Token![match]>()?;
378 let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
379 let content;
380 syn::braced!(content in input);
381 let mut arms = Vec::new();
382 while !content.is_empty() {
383 let pat = content.call(Pat::parse_multi_with_leading_vert)?;
384 let guard = if content.peek(Token![if]) {
385 content.parse::<Token![if]>()?;
386 Some(content.parse::<Expr>()?)
387 } else {
388 None
389 };
390 content.parse::<Token![=>]>()?;
391 let value = parse_section_value(&content, width)?;
393 if content.peek(Token![,]) {
394 content.parse::<Token![,]>()?;
395 }
396 arms.push(SectionMatchArm { pat, guard, value });
397 }
398 return Ok(SectionValue::Match { expr, arms });
399 }
400
401 if width == 1 {
402 return Ok(SectionValue::Single(parse_section_fragment(input)?));
403 }
404
405 let content;
406 syn::parenthesized!(content in input);
407 let mut items = Vec::new();
408 while !content.is_empty() {
409 items.push(parse_section_value(&content, 1)?);
412 if content.is_empty() {
413 break;
414 }
415 content.parse::<Token![,]>()?;
416 }
417
418 if items.len() != width {
419 return Err(input.error(format!(
420 "sql_forge!: grouped section value must provide exactly {} items",
421 width,
422 )));
423 }
424
425 Ok(SectionValue::Grouped(items))
426}
427
impl Parse for SqlForgeInput {
    /// Parses the full `sql_forge!` argument list.
    ///
    /// Leading forms (everything up to and including the SQL literal):
    /// - `"SQL"`: no DB, no result;
    /// - `scalar Model, "SQL"`;
    /// - `(>name = Model, ...), "SQL"`: a result map;
    /// - `Ty, "SQL"`: `Ty` is taken as the DB when `is_db_type` accepts it, otherwise as the
    ///   single result model;
    /// - `Db, scalar Model, "SQL"`, `Db, (>name = Model, ...), "SQL"`, `Db, Model, "SQL"`.
    ///
    /// After the SQL literal, comma-separated trailing arguments may appear in any order:
    /// `..batch_expr` (at most once), one parameter source (`(:name = expr, ...)` or any
    /// expression), and one section map `(#name = value, ...)`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // A leading parenthesized argument may only be a result map.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // A leading type is either the DB or the model; what follows it decides.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) && is_db_type(&first_ty) {
                let db = first_ty;
                let sql = parse_sql_template(input)?;
                (Some(db), ResultSpec::None, false, sql)
            } else if input.peek(LitStr) {
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `Db, Model, "SQL"`.
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Trailing arguments are only parsed when a comma follows the SQL literal.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // A parenthesized ordinary expression is a struct-style source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Arguments are comma separated; a missing comma ends the list, and any
                // leftovers are reported below.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
581
582fn resolve_db_from_env() -> Result<Type, String> {
588 if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
589 return syn::parse_str::<Type>(&val).map_err(|err| {
590 format!(
591 "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
592 val, err
593 )
594 });
595 }
596
597 let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
598 Ok(d) => d,
599 Err(_) => {
600 return Err(
601 "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
602 or configure [package.metadata.sql_forge] in Cargo.toml"
603 .to_string(),
604 );
605 }
606 };
607 let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
608
609 let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
610 format!(
611 "sql_forge!: failed to read {}: {}",
612 manifest_path.display(),
613 err
614 )
615 })?;
616
617 let value: toml::Value = toml::from_str(&cargo_toml)
618 .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
619
620 let db_str = value
621 .get("package")
622 .and_then(|v| v.get("metadata"))
623 .and_then(|v| v.get("sql_forge"))
624 .and_then(|v| v.get("db"))
625 .and_then(|v| v.as_str())
626 .ok_or({
627 "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
628 SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
629 })?;
630
631 syn::parse_str::<Type>(db_str).map_err(|err| {
632 format!(
633 "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
634 db_str, err
635 )
636 })
637}
638
639fn uses_dollar_params(db: &Type) -> bool {
641 let Type::Path(type_path) = db else {
642 return false;
643 };
644 type_path
645 .path
646 .segments
647 .last()
648 .is_some_and(|s| s.ident == "Postgres")
649}
650
651fn is_db_type(ty: &Type) -> bool {
653 let Type::Path(type_path) = ty else {
654 return false;
655 };
656 if type_path.qself.is_some() {
657 return false;
658 }
659 let segs = &type_path.path.segments;
660 if segs.len() != 2 {
661 return false;
662 }
663 segs[0].ident == "sqlx"
664 && ["MySql", "Postgres", "Sqlite"].contains(&segs[1].ident.to_string().as_str())
665}
666
667fn is_builtin_scalar_type(ty: &Type) -> bool {
669 let Type::Path(type_path) = ty else {
670 return false;
671 };
672
673 if type_path.qself.is_some()
674 || type_path.path.leading_colon.is_some()
675 || type_path.path.segments.len() != 1
676 {
677 return false;
678 }
679
680 let ident = &type_path.path.segments[0].ident;
681 ident == "i8"
682 || ident == "i16"
683 || ident == "i32"
684 || ident == "i64"
685 || ident == "isize"
686 || ident == "u8"
687 || ident == "u16"
688 || ident == "u32"
689 || ident == "u64"
690 || ident == "usize"
691 || ident == "f32"
692 || ident == "f64"
693 || ident == "bool"
694 || ident == "String"
695}
696
697fn scalar_output_type(model: &Type) -> Option<&Type> {
699 if is_builtin_scalar_type(model) {
700 return Some(model);
701 }
702 None
703}
704
705fn push_text_segment(out: &mut Vec<Segment>, text: String) {
707 if text.is_empty() {
708 return;
709 }
710 match out.last_mut() {
711 Some(Segment::Text(existing)) => existing.push_str(&text),
712 _ => out.push(Segment::Text(text)),
713 }
714}
715
/// Splits a SQL template into `Segment`s.
///
/// Recognized placeholders:
/// - `{( ... )}`: a batch section. Everything after `{` up to the closing `}` (including
///   the outer parentheses) is split with `parse_text_parts`. Nested `{`, unbalanced
///   parentheses, a missing `}` and list parameters (`:name[]`) are rejected.
/// - `{#name}`: a section placeholder. Everything up to the next `}` becomes the name,
///   which is only checked for emptiness here.
///
/// Any other `{` is kept as literal text.
///
/// # Errors
/// Returns a `sql_forge!:`-prefixed message for the first malformed placeholder.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        if chars.peek() == Some(&'(') {
            // Batch section `{( ... )}`: flush pending text, then collect up to `}`.
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` not followed by `#` is ordinary SQL text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Section placeholder `{#name}`: skip `#`, flush pending text, read the name.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
809
/// Returns `true` if `ch` may begin a named parameter: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}

/// Returns `true` if `ch` may continue a named parameter: an ASCII letter, digit or `_`.
fn is_ident_continue(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
823
/// Strips a type/nullability suffix from a backticked alias. Everything from the first `!`,
/// `?` or `:` onward is dropped and trailing whitespace is trimmed
/// (e.g. `id!` becomes `id`, `name: String` becomes `name`).
///
/// The input is returned unchanged when none of those characters occurs, or when nothing
/// would remain before the first one.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let base = match content.find(|c: char| matches!(c, '!' | '?' | ':')) {
        Some(cut) => content[..cut].trim_end(),
        None => content,
    };
    let chosen = if base.is_empty() { content } else { base };
    chosen.to_string()
}

/// Rewrites the inside of every closed backtick pair in `text` with
/// `sanitize_backticked_alias_ident`. All other text is copied as-is.
///
/// An unmatched trailing backtick and everything after it are copied verbatim.
fn sanitize_runtime_sql_text(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('`') {
        out.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];
        match after_open.find('`') {
            Some(close) => {
                out.push('`');
                out.push_str(&sanitize_backticked_alias_ident(&after_open[..close]));
                out.push('`');
                rest = &after_open[close + 1..];
            }
            None => {
                out.push('`');
                out.push_str(after_open);
                return out;
            }
        }
    }

    out.push_str(rest);
    out
}
881
882fn parse_text_parts(text: &str) -> Vec<TextPart> {
884 let mut parts = Vec::new();
885 let mut last = 0usize;
886 let mut iter = text.char_indices().peekable();
887
888 while let Some((idx, ch)) = iter.next() {
889 if ch != ':' {
890 continue;
891 }
892
893 let Some(&(next_idx, next_ch)) = iter.peek() else {
894 continue;
895 };
896
897 if !is_ident_start(next_ch) {
898 continue;
899 }
900
901 if text[..idx].ends_with(':') {
902 continue;
903 }
904
905 if last < idx {
906 parts.push(TextPart::Lit(text[last..idx].to_string()));
907 }
908
909 iter.next();
910
911 let mut name = String::new();
912 name.push(next_ch);
913 let mut end = next_idx + next_ch.len_utf8();
914
915 while let Some(&(j, c)) = iter.peek() {
916 if is_ident_continue(c) {
917 name.push(c);
918 end = j + c.len_utf8();
919 iter.next();
920 } else {
921 break;
922 }
923 }
924
925 let mut is_list = false;
926 if text[end..].starts_with("[]") {
927 is_list = true;
928 end += 2;
929 }
930
931 parts.push(TextPart::Param { name, is_list });
932 last = end;
933 }
934
935 if last < text.len() {
936 parts.push(TextPart::Lit(text[last..].to_string()));
937 }
938
939 parts
940}
941
942fn render_validator_text(
944 text: &str,
945 use_dollar_params: bool,
946 param_offset: &mut usize,
947 list_count: usize,
948) -> (String, Vec<(String, bool)>) {
949 let mut out_sql = String::new();
950 let mut occurrences = Vec::new();
951
952 for part in parse_text_parts(text) {
953 match part {
954 TextPart::Lit(lit) => out_sql.push_str(&lit),
955 TextPart::Param { name, is_list } => {
956 if is_list && list_count > 1 {
957 let slots: Vec<String> = if use_dollar_params {
958 (0..list_count)
959 .map(|i| format!("${}", *param_offset + i + 1))
960 .collect()
961 } else {
962 (0..list_count).map(|_| "?".to_string()).collect()
963 };
964 if use_dollar_params {
965 *param_offset += list_count;
966 }
967 out_sql.push_str(&slots.join(", "));
968 } else if use_dollar_params {
969 *param_offset += 1;
970 write!(out_sql, "${}", *param_offset).unwrap();
971 } else {
972 out_sql.push('?');
973 }
974 occurrences.push((name, is_list));
975 }
976 }
977 }
978
979 (out_sql, occurrences)
980}
981
982fn strip_expr(expr: &Expr) -> &Expr {
984 match expr {
985 Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
986 Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
987 Expr::Block(ExprBlock { block, .. }) => {
988 if block.stmts.len() != 1 {
989 return expr;
990 }
991 match &block.stmts[0] {
992 Stmt::Expr(inner, None) => strip_expr(inner),
993 _ => expr,
994 }
995 }
996 _ => expr,
997 }
998}
999
1000fn extract_lit_str(expr: &Expr) -> Option<String> {
1002 match strip_expr(expr) {
1003 Expr::Lit(ExprLit {
1004 lit: Lit::Str(lit), ..
1005 }) => Some(lit.value()),
1006 _ => None,
1007 }
1008}
1009
/// Builds the identifier of the boolean flag that stands in for a `{>name}` result-key
/// placeholder.
///
/// The `__enhanced_result_flag_` prefix must stay in sync with the prefix stripped by
/// `expr_result_flag_key`.
fn result_flag_ident(name: &str) -> syn::Ident {
    format_ident!("__enhanced_result_flag_{}", name)
}
1018
1019fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
1023 fn walk(stream: TokenStream2) -> TokenStream2 {
1024 let mut out = TokenStream2::new();
1025 let iter = stream.into_iter().peekable();
1026
1027 for token in iter {
1028 match token {
1029 TokenTree::Group(group) => {
1030 if group.delimiter() == Delimiter::Brace {
1031 let mut inner = group.stream().into_iter();
1032 let first = inner.next();
1033 let second = inner.next();
1034 let third = inner.next();
1035
1036 if let (
1037 Some(TokenTree::Punct(p)),
1038 Some(TokenTree::Ident(name_ident)),
1039 None,
1040 ) = (first, second, third)
1041 {
1042 if p.as_char() == '>' {
1043 let ident = result_flag_ident(&name_ident.to_string());
1044 out.extend(std::iter::once(TokenTree::Ident(ident)));
1045 continue;
1046 }
1047 }
1048 }
1049
1050 let new_inner = walk(group.stream());
1051 let mut new_group = Group::new(group.delimiter(), new_inner);
1052 new_group.set_span(group.span());
1053 out.extend(std::iter::once(TokenTree::Group(new_group)));
1054 }
1055 other => out.extend(std::iter::once(other)),
1056 }
1057 }
1058
1059 out
1060 }
1061
1062 walk(input)
1063}
1064
1065fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1067 keys.iter()
1068 .map(|key| {
1069 let ident = result_flag_ident(key);
1070 let enabled = Some(key.as_str()) == active_key;
1071 quote! { let #ident: bool = #enabled; }
1072 })
1073 .collect()
1074}
1075
1076fn transpose_section_case_matrix(
1078 case_matrix: Vec<Vec<SectionFragment>>,
1079 width: usize,
1080) -> Result<Vec<Vec<SectionFragment>>, String> {
1081 let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1082
1083 for row in case_matrix {
1084 if row.len() != width {
1085 return Err(
1086 "sql_forge!: grouped sections must return one item per section".to_string(),
1087 );
1088 }
1089 for (section_idx, fragment) in row.into_iter().enumerate() {
1090 per_section[section_idx].push(fragment);
1091 }
1092 }
1093
1094 Ok(per_section)
1095}
1096
/// Expands a section value into a "case matrix": one row per statically enumerated case,
/// each row holding `width` fragments (one per section key).
///
/// - `Single` yields one row and is only valid for `width == 1`.
/// - `Grouped` expands each item on its own, then pairs item variants cyclically
///   (`case_idx % variants.len()`). The row count is the largest variant count, so every
///   variant appears at least once, but not every combination is produced.
/// - `Match` concatenates the cases of all arms, wrapping each fragment's parameter
///   expressions so they are evaluated inside that arm. When the scrutinee is a result-flag
///   identifier (see `expr_result_flag_key`), unguarded arms whose pattern cannot match the
///   flag's value for `active_key` are skipped.
///
/// # Errors
/// Returns a message when a grouped value has the wrong number of items, or when a match
/// produces no cases.
fn collect_section_case_matrix(
    value: SectionValue,
    width: usize,
    active_key: Option<&str>,
) -> Result<Vec<Vec<SectionFragment>>, String> {
    match value {
        SectionValue::Single(fragment) => {
            if width != 1 {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }
            Ok(vec![vec![fragment]])
        }
        SectionValue::Grouped(values) => {
            if values.len() != width {
                return Err(
                    "sql_forge!: grouped sections must return one item per section".to_string(),
                );
            }

            // variants_by_section[i] holds every possible fragment for section i.
            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
            let mut nmax = 1usize;

            for value in values {
                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
                for mut row in item_matrix {
                    // Each width-1 row must contain exactly one fragment.
                    let fragment = row.pop().ok_or_else(|| {
                        "sql_forge!: grouped sections must return one item per section".to_string()
                    })?;
                    if !row.is_empty() {
                        return Err(
                            "sql_forge!: grouped sections must return one item per section"
                                .to_string(),
                        );
                    }
                    item_variants.push(fragment);
                }
                if item_variants.is_empty() {
                    return Err("sql_forge!: section match must have at least one arm".to_string());
                }
                nmax = nmax.max(item_variants.len());
                variants_by_section.push(item_variants);
            }

            // Cycle shorter variant lists so that every variant of every item is used.
            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
            for case_idx in 0..nmax {
                let mut row = Vec::<SectionFragment>::with_capacity(width);
                for variants in &variants_by_section {
                    row.push(variants[case_idx % variants.len()].clone());
                }
                case_matrix.push(row);
            }

            Ok(case_matrix)
        }
        SectionValue::Match { expr, arms } => {
            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();

            if let Some(key) = expr_result_flag_key(&expr) {
                // The scrutinee is a result-key flag whose value for `active_key` is known,
                // so unguarded arms that cannot match it are pruned.
                let target = active_key == Some(key.as_str());
                for arm in arms {
                    if arm.guard.is_none() {
                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
                            continue;
                        }
                    }
                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
                    wrap_section_case_matrix_for_match_arm(
                        &mut arm_cases,
                        &expr,
                        &arm.pat,
                        arm.guard.as_ref(),
                    );
                    case_matrix.extend(arm_cases);
                }
            } else {
                for arm in arms {
                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
                    wrap_section_case_matrix_for_match_arm(
                        &mut arm_cases,
                        &expr,
                        &arm.pat,
                        arm.guard.as_ref(),
                    );
                    case_matrix.extend(arm_cases);
                }
            }

            if case_matrix.is_empty() {
                return Err("sql_forge!: section match must have at least one arm".to_string());
            }

            Ok(case_matrix)
        }
    }
}
1196
/// Wraps `expr` so that it is evaluated inside one match arm:
/// `match &(match_expr) { pat [if guard] => { ... }, _ => unreachable!(...) }`.
/// This lets `expr` use bindings introduced by `pat`.
///
/// When the pattern binds identifiers at its top level (checked one nesting level deep),
/// the arm body first emits `let _ = &ident;` for each identifier returned by
/// `pat_var_idents`, presumably to silence unused-variable warnings, and then yields `expr`
/// by value. Otherwise the arm yields `&(expr)`.
///
/// NOTE(review): the two branches yield different types (`T` vs `&T`). The binding check is
/// shallow: `Some((a, b))` counts as non-binding, while a bare path pattern such as `None`
/// parses as `Pat::Ident` and counts as binding. Confirm both behaviours are intended.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        // Touch every bound identifier before evaluating `expr`.
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1266
1267fn wrap_params_source_for_match_arm(
1269 params: &mut ParamsSource,
1270 match_expr: &Expr,
1271 pat: &Pat,
1272 guard: Option<&Expr>,
1273) {
1274 match params {
1275 ParamsSource::None => {}
1276 ParamsSource::Map(entries) => {
1277 for entry in entries {
1278 entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1279 }
1280 }
1281 ParamsSource::Struct(expr) => {
1282 **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1283 }
1284 }
1285}
1286
1287fn wrap_section_case_matrix_for_match_arm(
1289 case_matrix: &mut [Vec<SectionFragment>],
1290 match_expr: &Expr,
1291 pat: &Pat,
1292 guard: Option<&Expr>,
1293) {
1294 for row in case_matrix {
1295 for fragment in row {
1296 wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1297 }
1298 }
1299}
1300
1301fn collect_section_variants(
1311 value: SectionValue,
1312 width: usize,
1313) -> Result<Vec<Vec<SectionFragment>>, String> {
1314 transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1315}
1316
1317fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1319 match strip_expr(expr) {
1320 Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1321 let name = path.path.segments[0].ident.to_string();
1322 name.strip_prefix("__enhanced_result_flag_")
1323 .map(|v| v.to_string())
1324 }
1325 _ => None,
1326 }
1327}
1328
1329fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1331 match pat {
1332 Pat::Lit(expr_lit) => match &expr_lit.lit {
1333 Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1334 _ => None,
1335 },
1336 Pat::Wild(_) => Some(true),
1337 _ => None,
1338 }
1339}
1340
1341fn collect_section_variants_for_result(
1346 value: SectionValue,
1347 width: usize,
1348 active_key: Option<&str>,
1349) -> Result<Vec<Vec<SectionFragment>>, String> {
1350 transpose_section_case_matrix(
1351 collect_section_case_matrix(value, width, active_key)?,
1352 width,
1353 )
1354}
1355
1356fn build_param_bindings(
1364 params: &ParamsSource,
1365 used_param_names: &[String],
1366 prefix: &str,
1367 for_validator: bool,
1368 enforce_usage_check: bool,
1369) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1370 let mut declared_params = HashMap::<String, syn::Ident>::new();
1371 let mut bindings = Vec::<TokenStream2>::new();
1372
1373 match params {
1374 ParamsSource::None => {}
1375 ParamsSource::Map(entries) => {
1376 for entry in entries {
1377 let key = entry.name.to_string();
1378 if declared_params.contains_key(&key) {
1379 return Err(syn::Error::new(
1380 entry.name.span(),
1381 "sql_forge!: duplicated parameter mapping",
1382 )
1383 .to_compile_error()
1384 .into());
1385 }
1386 if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1387 return Err(syn::Error::new(
1388 entry.name.span(),
1389 format!(
1390 "sql_forge!: parameter :{} is unused in the SQL template",
1391 key,
1392 ),
1393 )
1394 .to_compile_error()
1395 .into());
1396 }
1397 let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1398 let expr = &entry.expr;
1399 if for_validator {
1400 bindings.push(quote! {
1401 let #local_ident = &(#expr);
1402 });
1403 } else {
1404 bindings.push(quote! {
1405 let #local_ident = #expr;
1406 });
1407 }
1408 declared_params.insert(key, local_ident);
1409 }
1410 }
1411 ParamsSource::Struct(expr) => {
1412 let source_ident = format_ident!("__enhanced_source_{}", prefix);
1413 bindings.push(quote! {
1414 let #source_ident = &(#expr);
1415 });
1416 for name in used_param_names {
1417 let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1418 let field_ident = format_ident!("{}", name);
1419 if for_validator {
1420 bindings.push(quote! {
1421 let #local_ident = &#source_ident.#field_ident;
1422 });
1423 } else {
1424 bindings.push(quote! {
1425 let #local_ident = #source_ident.#field_ident;
1426 });
1427 }
1428 declared_params.insert(name.to_string(), local_ident);
1429 }
1430 }
1431 }
1432
1433 Ok((declared_params, bindings))
1434}
1435
1436struct ValidatorRenderContext<'a> {
1437 local_params: &'a HashMap<String, syn::Ident>,
1438 top_level_params: &'a HashMap<String, syn::Ident>,
1439 allow_top_level_fallback: bool,
1440 use_dollar_params: bool,
1441 sql_span: Span,
1442 list_count: usize,
1443}
1444
1445fn render_validator_args(
1450 sql: &str,
1451 param_offset: &mut usize,
1452 arg_index: &mut usize,
1453 context: &ValidatorRenderContext<'_>,
1454) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1455 let (rendered_sql, occurrences) = render_validator_text(
1456 sql,
1457 context.use_dollar_params,
1458 param_offset,
1459 context.list_count,
1460 );
1461
1462 let mut setup = Vec::<TokenStream2>::new();
1463 let mut args = Vec::<TokenStream2>::new();
1464
1465 for (name, is_list) in occurrences {
1466 let local_ident = if context.allow_top_level_fallback {
1467 context
1468 .local_params
1469 .get(&name)
1470 .or_else(|| context.top_level_params.get(&name))
1471 } else {
1472 context.local_params.get(&name)
1473 };
1474
1475 let Some(local_ident) = local_ident else {
1476 return Err(syn::Error::new(
1477 context.sql_span,
1478 format!("sql_forge!: parameter :{} has no mapping", name),
1479 )
1480 .to_compile_error()
1481 .into());
1482 };
1483
1484 if is_list {
1485 for _ in 0..context.list_count {
1486 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1487 *arg_index += 1;
1488 if context.use_dollar_params {
1489 setup.push(quote! {
1490 let #value_ident = sql_forge::sql_forge_validator_value(
1491 (#local_ident)
1492 .as_slice()
1493 .first()
1494 .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1495 );
1496 });
1497 } else {
1498 setup.push(quote! {
1499 let #value_ident = (#local_ident)
1500 .as_slice()
1501 .first()
1502 .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1503 });
1504 }
1505 args.push(quote! { #value_ident });
1506 }
1507 } else {
1508 let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1509 *arg_index += 1;
1510 if context.use_dollar_params {
1511 setup.push(quote! {
1512 let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1513 });
1514 } else {
1515 setup.push(quote! {
1516 let #value_ident = #local_ident;
1517 });
1518 }
1519 args.push(quote! { #value_ident });
1520 }
1521 }
1522
1523 Ok((rendered_sql, setup, args))
1524}
1525
1526fn render_runtime_fragment(
1533 fragment: &SectionFragment,
1534 local_params: &HashMap<String, syn::Ident>,
1535) -> Result<TokenStream2, TokenStream> {
1536 let mut steps = Vec::<TokenStream2>::new();
1537
1538 for part in parse_text_parts(&fragment.sql) {
1539 match part {
1540 TextPart::Lit(lit) => {
1541 let lit_str = LitStr::new(&lit, fragment.span);
1542 steps.push(quote! { __builder.push(#lit_str); });
1543 }
1544 TextPart::Param { name, is_list } => {
1545 let Some(local_ident) = local_params.get(&name) else {
1546 return Err(syn::Error::new(
1547 fragment.span,
1548 format!("sql_forge!: parameter :{} has no mapping", name),
1549 )
1550 .to_compile_error()
1551 .into());
1552 };
1553
1554 if is_list {
1555 steps.push(quote! {
1556 let __enhanced_values = #local_ident;
1557 let mut __separated = __builder.separated(", ");
1558 for __value in __enhanced_values {
1559 __separated.push_bind(__value);
1560 }
1561 });
1562 } else {
1563 steps.push(quote! {
1564 __builder.push_bind(#local_ident);
1565 });
1566 }
1567 }
1568 }
1569 }
1570
1571 Ok(quote! { #( #steps )* })
1572}
1573
1574fn is_pat_binding(ident: &Ident) -> bool {
1576 let name = ident.to_string();
1577 !name.is_empty()
1578 && name
1579 .chars()
1580 .next()
1581 .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1582}
1583
1584fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1586 let mut names = Vec::new();
1587 fn walk(p: &Pat, names: &mut Vec<Ident>) {
1588 match p {
1589 Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1590 Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1591 Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1592 Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1593 Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1594 Pat::Paren(pp) => walk(&pp.pat, names),
1595 Pat::Reference(pr) => walk(&pr.pat, names),
1596 Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1597 Pat::Type(pt) => walk(&pt.pat, names),
1598 _ => {}
1599 }
1600 }
1601 walk(pat, &mut names);
1602 names
1603}
1604
1605fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1607 match value {
1608 SectionValue::Single(f) => {
1609 if collect_used_param_names_in_sql(&f.sql)
1610 .iter()
1611 .any(|n| n == name)
1612 {
1613 return true;
1614 }
1615 if let ParamsSource::Map(entries) = &f.params {
1616 for e in entries {
1617 let expr = &e.expr;
1618 let expr_str = quote! { #expr }.to_string();
1619 if expr_str.trim() == name {
1620 return true;
1621 }
1622 }
1623 }
1624 false
1625 }
1626 SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1627 SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1628 let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1629 .into_iter()
1630 .map(|i| i.to_string())
1631 .collect();
1632 if pat_vars.contains(name) {
1633 false
1634 } else {
1635 section_value_refers_to(&arm.value, name)
1636 }
1637 }),
1638 }
1639}
1640
1641fn build_section_runtime_action(
1643 value: &SectionValue,
1644 section_idx: usize,
1645 prefix: &str,
1646) -> Result<TokenStream2, TokenStream> {
1647 match value {
1648 SectionValue::Single(fragment) => {
1649 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1650 let (local_params, bindings) =
1651 build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1652 let body = render_runtime_fragment(fragment, &local_params)?;
1653 Ok(quote! {{ #( #bindings )* #body }})
1654 }
1655 SectionValue::Grouped(fragments) => build_section_runtime_action(
1656 &fragments[section_idx],
1657 0,
1658 &format!("{}_grouped_{}", prefix, section_idx),
1659 ),
1660 SectionValue::Match { expr, arms } => {
1661 let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1662 .iter()
1663 .enumerate()
1664 .map(|(arm_idx, arm)| {
1665 let pat = &arm.pat;
1666 let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1667 let body = build_section_runtime_action(
1668 &arm.value,
1669 section_idx,
1670 &format!("{}_{}", prefix, arm_idx),
1671 )?;
1672 let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1673 .into_iter()
1674 .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1675 .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1676 .collect();
1677 Ok::<TokenStream2, TokenStream>(quote! {
1678 #pat #guard_tokens => {
1679 #( #noop_refs )*
1680 #body
1681 }
1682 })
1683 })
1684 .collect();
1685 let arm_tokens = arm_tokens?;
1686 Ok(quote! {
1687 match #expr {
1688 #( #arm_tokens ),*
1689 }
1690 })
1691 }
1692 }
1693}
1694
1695fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1697 let mut names = Vec::new();
1698 let mut seen = HashSet::<String>::new();
1699
1700 for segment in segments {
1701 match segment {
1702 Segment::Text(text) => {
1703 for name in collect_used_param_names_in_sql(text) {
1704 if seen.insert(name.clone()) {
1705 names.push(name);
1706 }
1707 }
1708 }
1709 Segment::Batch { parts } => {
1710 for part in parts {
1711 if let TextPart::Param { name, .. } = part {
1712 if seen.insert(name.clone()) {
1713 names.push(name.clone());
1714 }
1715 }
1716 }
1717 }
1718 _ => {}
1719 }
1720 }
1721
1722 names
1723}
1724
1725fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1727 let mut names = Vec::new();
1728 let mut seen = HashSet::<String>::new();
1729 for part in parse_text_parts(sql) {
1730 if let TextPart::Param { name, .. } = part {
1731 if seen.insert(name.to_string()) {
1732 names.push(name);
1733 }
1734 }
1735 }
1736 names
1737}
1738
1739#[proc_macro]
1965#[allow(clippy::too_many_lines)]
1966pub fn sql_forge(input: TokenStream) -> TokenStream {
1967 let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1969 let SqlForgeInput {
1970 db,
1971 result,
1972 force_scalar,
1973 sql,
1974 params,
1975 sections,
1976 batch,
1977 } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1978 Ok(v) => v,
1979 Err(err) => return err.to_compile_error().into(),
1980 };
1981
1982 let db = match db {
1984 Some(db) => db,
1985 None => match resolve_db_from_env() {
1986 Ok(db) => db,
1987 Err(msg) => {
1988 return syn::Error::new(Span::call_site(), msg)
1989 .to_compile_error()
1990 .into();
1991 }
1992 },
1993 };
1994
1995 let use_dollar_params = uses_dollar_params(&db);
1996 let is_sqlite = if let syn::Type::Path(type_path) = &db {
1997 type_path
1998 .path
1999 .segments
2000 .last()
2001 .is_some_and(|s| s.ident == "Sqlite")
2002 } else {
2003 false
2004 };
2005 let list_count: usize = if is_sqlite { 1 } else { 3 };
2006
2007 let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
2011 ResultSpec::None => {
2012 vec![(None, None, None)]
2013 }
2014 ResultSpec::Single(ref model) => {
2015 let model_ty = (**model).clone();
2016 let scalar = if force_scalar {
2017 Some(model_ty.clone())
2018 } else {
2019 scalar_output_type(model.as_ref()).cloned()
2020 };
2021 vec![(None, Some(model_ty), scalar)]
2022 }
2023 ResultSpec::Group(ref cases) => {
2024 if force_scalar {
2025 return syn::Error::new(
2026 Span::call_site(),
2027 "sql_forge!: scalar mode is not supported for grouped result maps",
2028 )
2029 .to_compile_error()
2030 .into();
2031 }
2032
2033 let mut out = Vec::new();
2034 let mut seen = HashSet::new();
2035 for case in cases {
2036 let key = case.name.to_string();
2037 if !seen.insert(key.clone()) {
2038 return syn::Error::new(
2039 case.name.span(),
2040 "sql_forge!: duplicated key in result map",
2041 )
2042 .to_compile_error()
2043 .into();
2044 }
2045
2046 let model = case.model.clone();
2047 let scalar = if case.force_scalar {
2048 Some(model.clone())
2049 } else {
2050 scalar_output_type(&case.model).cloned()
2051 };
2052 out.push((Some(key), Some(model), scalar));
2053 }
2054 out
2055 }
2056 };
2057 let group_result_keys: Vec<String> = result_cases
2058 .iter()
2059 .filter_map(|(key, _, _)| key.as_ref().cloned())
2060 .collect();
2061 let is_grouped_result = !group_result_keys.is_empty();
2062 let sql_span = sql.span();
2063
2064 let segments = match sql.into_segments() {
2066 Ok(segments) => segments,
2067 Err(msg) => {
2068 return syn::Error::new(sql_span, msg).to_compile_error().into();
2069 }
2070 };
2071
2072 let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2073 match (&batch, has_batch_segment) {
2074 (None, true) => {
2075 return syn::Error::new(
2076 sql_span,
2077 "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2078 was provided"
2079 )
2080 .to_compile_error()
2081 .into();
2082 }
2083 (Some(_), false) => {
2084 return syn::Error::new(
2085 sql_span,
2086 "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2087 batch section",
2088 )
2089 .to_compile_error()
2090 .into();
2091 }
2092 _ => {}
2093 }
2094
2095 let used_param_names = collect_used_param_names(&segments);
2096
2097 let batch_param_names: std::collections::HashSet<String> = segments
2102 .iter()
2103 .filter_map(|s| {
2104 if let Segment::Batch { parts } = s {
2105 Some(parts.iter().filter_map(|p| {
2106 if let TextPart::Param { name, .. } = p {
2107 Some(name.clone())
2108 } else {
2109 None
2110 }
2111 }))
2112 } else {
2113 None
2114 }
2115 })
2116 .flatten()
2117 .collect();
2118 let top_level_used_names: Vec<String> = used_param_names
2119 .iter()
2120 .filter(|n| !batch_param_names.contains(*n))
2121 .cloned()
2122 .collect();
2123
2124 let (declared_params, validator_param_bindings) =
2126 match build_param_bindings(¶ms, &top_level_used_names, "top_level", true, true) {
2127 Ok(v) => v,
2128 Err(err) => return err,
2129 };
2130
2131 let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2132
2133 for assign in §ions {
2135 let SectionAssign { names, value } = assign;
2136
2137 let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2139 for (section_idx, name_ident) in names.iter().enumerate() {
2140 let name = name_ident.to_string();
2141 if runtime_section_actions.contains_key(&name) {
2142 return syn::Error::new(
2143 name_ident.span(),
2144 "sql_forge!: duplicated section mapping",
2145 )
2146 .to_compile_error()
2147 .into();
2148 }
2149 let action = match build_section_runtime_action(
2150 value,
2151 section_idx,
2152 &format!("section_{}", name),
2153 ) {
2154 Ok(action) => action,
2155 Err(err) => return err,
2156 };
2157 named_actions.push((name, action));
2158 }
2159
2160 if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2162 return syn::Error::new(names[0].span(), msg)
2163 .to_compile_error()
2164 .into();
2165 }
2166
2167 for (name, action) in named_actions {
2168 runtime_section_actions.insert(name, action);
2169 }
2170 }
2171
2172 let sql_section_names: std::collections::HashSet<&str> = segments
2173 .iter()
2174 .filter_map(|seg| {
2175 if let Segment::Section { name } = seg {
2176 Some(name.as_str())
2177 } else {
2178 None
2179 }
2180 })
2181 .collect();
2182 for name in runtime_section_actions.keys() {
2183 if !sql_section_names.contains(name.as_str()) {
2184 return syn::Error::new(
2185 sql_span,
2186 format!(
2187 "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2188 name, name,
2189 ),
2190 )
2191 .to_compile_error()
2192 .into();
2193 }
2194 }
2195
2196 let mut generated_query_defs = Vec::<TokenStream2>::new();
2198 let mut generated_query_values = Vec::<TokenStream2>::new();
2199 let mut group_field_defs = Vec::<TokenStream2>::new();
2200 let mut group_method_defs = Vec::<TokenStream2>::new();
2201 let mut group_field_idents = Vec::<syn::Ident>::new();
2202 let mut group_field_tys = Vec::<TokenStream2>::new();
2203 let mut group_trait_impls = Vec::<TokenStream2>::new();
2204
2205 let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2206
2207 for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2208 let suffix = result_key.as_deref().unwrap_or("single");
2209 let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2210 let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2211
2212 let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2213
2214 let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2215 for assign in §ions {
2216 let SectionAssign { names, value } = assign;
2217 let variants_by_section = match collect_section_variants_for_result(
2218 value.clone(),
2219 names.len(),
2220 result_key.as_deref(),
2221 ) {
2222 Ok(v) => v,
2223 Err(msg) => {
2224 return syn::Error::new(names[0].span(), msg)
2225 .to_compile_error()
2226 .into();
2227 }
2228 };
2229
2230 for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2231 section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2232 }
2233 }
2234
2235 let mut nmax = 1usize;
2236 for segment in &segments {
2237 if let Segment::Section { name } = segment {
2238 if let Some(variants) = section_variants_for_validation.get(name) {
2239 if variants.is_empty() {
2240 return syn::Error::new(
2241 sql_span,
2242 format!("sql_forge!: section {{#{}}} has no possible variants", name),
2243 )
2244 .to_compile_error()
2245 .into();
2246 }
2247 nmax = nmax.max(variants.len());
2248 } else {
2249 return syn::Error::new(
2250 sql_span,
2251 format!("sql_forge!: section {{#{}}} has no mapping", name),
2252 )
2253 .to_compile_error()
2254 .into();
2255 }
2256 }
2257 }
2258
2259 let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2260 for case_idx in 0..nmax {
2261 let mut sql_case = String::new();
2262 let mut case_setup = Vec::<TokenStream2>::new();
2263 let mut case_args = Vec::<TokenStream2>::new();
2264 let mut param_offset = 0usize;
2265 let mut arg_index = 0usize;
2266 let empty_params = HashMap::<String, syn::Ident>::new();
2267 let root_validator_context = ValidatorRenderContext {
2268 local_params: &empty_params,
2269 top_level_params: &declared_params,
2270 allow_top_level_fallback: true,
2271 use_dollar_params,
2272 sql_span,
2273 list_count,
2274 };
2275
2276 for segment in &segments {
2277 match segment {
2278 Segment::Text(text) => {
2279 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2280 text,
2281 &mut param_offset,
2282 &mut arg_index,
2283 &root_validator_context,
2284 ) {
2285 Ok(value) => value,
2286 Err(err) => return err,
2287 };
2288 sql_case.push_str(&chunk_sql);
2289 case_setup.extend(chunk_setup);
2290 case_args.extend(chunk_args);
2291 }
2292 Segment::Section { name } => {
2293 let Some(variants) = section_variants_for_validation.get(name) else {
2294 return syn::Error::new(
2295 sql_span,
2296 format!("sql_forge!: section {{#{}}} has no mapping", name),
2297 )
2298 .to_compile_error()
2299 .into();
2300 };
2301
2302 let fragment = &variants[case_idx % variants.len()];
2303 let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2304 let (local_params, bindings) = match build_param_bindings(
2305 &fragment.params,
2306 &used_param_names,
2307 &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2308 true,
2309 true,
2310 ) {
2311 Ok(value) => value,
2312 Err(err) => return err,
2313 };
2314 let section_validator_context = ValidatorRenderContext {
2315 local_params: &local_params,
2316 top_level_params: &declared_params,
2317 allow_top_level_fallback: false,
2318 use_dollar_params,
2319 sql_span: fragment.span,
2320 list_count,
2321 };
2322 let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2323 &fragment.sql,
2324 &mut param_offset,
2325 &mut arg_index,
2326 §ion_validator_context,
2327 ) {
2328 Ok(value) => value,
2329 Err(err) => return err,
2330 };
2331 sql_case.push_str(&chunk_sql);
2332 case_setup.extend(bindings);
2333 case_setup.extend(chunk_setup);
2334 case_args.extend(chunk_args);
2335 }
2336 Segment::Batch { parts } => {
2337 let mut first = true;
2338 for _ in 0..list_count {
2339 let sep = if first { "" } else { ", " };
2340 first = false;
2341 sql_case.push_str(sep);
2342 for tp in parts {
2343 match tp {
2344 TextPart::Lit(lit) => sql_case.push_str(lit),
2345 TextPart::Param { name, .. } => {
2346 if let Some(batch_expr) = &batch {
2347 let field_ident = format_ident!("{}", name);
2348 if use_dollar_params {
2349 param_offset += 1;
2350 write!(sql_case, "${}", param_offset).unwrap();
2351 } else {
2352 sql_case.push('?');
2353 }
2354 case_args.push(quote! { #batch_expr[0].#field_ident });
2355 } else if use_dollar_params {
2356 param_offset += 1;
2357 write!(sql_case, "${}", param_offset).unwrap();
2358 } else {
2359 sql_case.push('?');
2360 }
2361 }
2362 }
2363 }
2364 }
2365 }
2366 }
2367 }
2368
2369 validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2370 }
2371
2372 let mut validator_invocations = Vec::<TokenStream2>::new();
2373 for (sql_lit, case_setup, args) in &validator_cases {
2374 if model_opt.is_none() {
2375 if args.is_empty() {
2376 validator_invocations.push(quote! {
2377 {
2378 #( #case_setup )*
2379 let _ = sqlx::query_scalar!(
2380 #sql_lit,
2381 );
2382 }
2383 });
2384 } else {
2385 validator_invocations.push(quote! {
2386 {
2387 #( #case_setup )*
2388 let _ = sqlx::query_scalar!(
2389 #sql_lit,
2390 #( #args ),*
2391 );
2392 }
2393 });
2394 }
2395 } else if let Some(scalar_ty) = scalar_model_ty {
2396 if args.is_empty() {
2397 validator_invocations.push(quote! {
2398 {
2399 #( #case_setup )*
2400 let _ = sqlx::query_scalar!(
2401 #sql_lit,
2402 );
2403 }
2404 });
2405 } else {
2406 validator_invocations.push(quote! {
2407 {
2408 #( #case_setup )*
2409 let _ = sqlx::query_scalar!(
2410 #sql_lit,
2411 #( #args ),*
2412 );
2413 }
2414 });
2415 }
2416 let _ = scalar_ty;
2417 } else if args.is_empty() {
2418 validator_invocations.push(quote! {
2419 {
2420 #( #case_setup )*
2421 let _ = sqlx::query_as!(
2422 __EnhancedModel,
2423 #sql_lit,
2424 );
2425 }
2426 });
2427 } else {
2428 validator_invocations.push(quote! {
2429 {
2430 #( #case_setup )*
2431 let _ = sqlx::query_as!(
2432 __EnhancedModel,
2433 #sql_lit,
2434 #( #args ),*
2435 );
2436 }
2437 });
2438 }
2439 }
2440
2441 let model_alias = if let Some(model) = model_opt {
2442 if scalar_model_ty.is_none() {
2443 quote! { type __EnhancedModel = #model; }
2444 } else {
2445 quote! {}
2446 }
2447 } else {
2448 quote! {}
2449 };
2450 grouped_validator_invocations.push(quote! {
2451 {
2452 #( #flag_bindings )*
2453 #model_alias
2454 #( #validator_invocations )*
2455 }
2456 });
2457
2458 let (runtime_declared_params, runtime_param_bindings) =
2459 match build_param_bindings(¶ms, &used_param_names, "runtime", false, false) {
2460 Ok(v) => v,
2461 Err(err) => return err,
2462 };
2463
2464 let mut runtime_steps = Vec::<TokenStream2>::new();
2465 for (seg_idx, segment) in segments.iter().enumerate() {
2466 match segment {
2467 Segment::Text(text) => {
2468 for part in parse_text_parts(text) {
2469 match part {
2470 TextPart::Lit(lit) => {
2471 let lit = sanitize_runtime_sql_text(&lit);
2472 let lit_str = LitStr::new(&lit, sql_span);
2473 runtime_steps.push(quote! {
2474 __builder.push(#lit_str);
2475 });
2476 }
2477 TextPart::Param { name, is_list } => {
2478 let Some(local_ident) = runtime_declared_params.get(&name) else {
2479 return syn::Error::new(
2480 sql_span,
2481 format!("sql_forge!: parameter :{} has no mapping", name),
2482 )
2483 .to_compile_error()
2484 .into();
2485 };
2486
2487 if is_list {
2488 runtime_steps.push(quote! {
2489 let __enhanced_values = #local_ident;
2490 let mut __separated = __builder.separated(", ");
2491 for __value in __enhanced_values {
2492 __separated.push_bind(__value);
2493 }
2494 });
2495 } else {
2496 runtime_steps.push(quote! {
2497 __builder.push_bind(#local_ident);
2498 });
2499 }
2500 }
2501 }
2502 }
2503 }
2504 Segment::Section { name } => {
2505 let Some(section_action) = runtime_section_actions.get(name) else {
2506 let _ = seg_idx;
2507 return syn::Error::new(
2508 sql_span,
2509 format!("sql_forge!: section {{#{}}} has no mapping", name),
2510 )
2511 .to_compile_error()
2512 .into();
2513 };
2514 runtime_steps.push(quote! {
2515 #section_action
2516 });
2517 }
2518 Segment::Batch { parts } => {
2519 if let Some(batch_expr) = &batch {
2520 let mut body = Vec::<TokenStream2>::new();
2521 for part in parts {
2522 match part {
2523 TextPart::Lit(lit) => {
2524 let lit_str = LitStr::new(lit, sql_span);
2525 body.push(quote! {
2526 __builder.push(#lit_str);
2527 });
2528 }
2529 TextPart::Param { name, .. } => {
2530 let field_ident = format_ident!("{}", name);
2531 body.push(quote! {
2532 __builder.push_bind(__item.#field_ident);
2533 });
2534 }
2535 }
2536 }
2537 runtime_steps.push(quote! {
2538 {
2539 let mut __first = true;
2540 for __item in #batch_expr {
2541 if !__first {
2542 __builder.push(", ");
2543 }
2544 __first = false;
2545 #( #body )*
2546 }
2547 }
2548 });
2549 }
2550 }
2551 }
2552 }
2553
2554 let exec_methods = if model_opt.is_none() {
2555 quote! {
2556 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2557 where
2558 E: sqlx::Executor<'e, Database = #db>,
2559 {
2560 self.inner.build().execute(executor).await
2561 }
2562 }
2563 } else if let Some(scalar_ty) = scalar_model_ty {
2564 quote! {
2565 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2566 where
2567 E: sqlx::Executor<'e, Database = #db>,
2568 {
2569 self.inner
2570 .build_query_scalar::<#scalar_ty>()
2571 .fetch_all(executor)
2572 .await
2573 }
2574
2575 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2576 where
2577 E: sqlx::Executor<'e, Database = #db>,
2578 {
2579 self.inner
2580 .build_query_scalar::<#scalar_ty>()
2581 .fetch_one(executor)
2582 .await
2583 }
2584
2585 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2586 where
2587 E: sqlx::Executor<'e, Database = #db>,
2588 {
2589 self.inner
2590 .build_query_scalar::<#scalar_ty>()
2591 .fetch_optional(executor)
2592 .await
2593 }
2594
2595 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2596 where
2597 E: sqlx::Executor<'e, Database = #db>,
2598 {
2599 self.inner.build().execute(executor).await
2600 }
2601 }
2602 } else {
2603 let model = model_opt.as_ref().unwrap();
2604 quote! {
2605 async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2606 where
2607 E: sqlx::Executor<'e, Database = #db>,
2608 {
2609 self.inner.build_query_as::<#model>().fetch_all(executor).await
2610 }
2611
2612 async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2613 where
2614 E: sqlx::Executor<'e, Database = #db>,
2615 {
2616 self.inner.build_query_as::<#model>().fetch_one(executor).await
2617 }
2618
2619 async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2620 where
2621 E: sqlx::Executor<'e, Database = #db>,
2622 {
2623 self.inner
2624 .build_query_as::<#model>()
2625 .fetch_optional(executor)
2626 .await
2627 }
2628
2629 async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2630 where
2631 E: sqlx::Executor<'e, Database = #db>,
2632 {
2633 self.inner.build().execute(executor).await
2634 }
2635 }
2636 };
2637
2638 let final_type: TokenStream2 = if let Some(model) = model_opt {
2639 if let Some(scalar_ty) = scalar_model_ty {
2640 quote! { #scalar_ty }
2641 } else {
2642 quote! { #model }
2643 }
2644 } else {
2645 quote! {}
2646 };
2647 let trait_impl = if model_opt.is_none() {
2648 quote! {
2649 impl sql_forge::SqlForgeQueryExecute
2650 for #query_ident
2651 {
2652 type Db = #db;
2653
2654 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2655 where
2656 Self: Sized + 'e,
2657 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2658 #db: 'e,
2659 {
2660 #query_ident::execute(self, executor)
2661 }
2662 }
2663 }
2664 } else {
2665 quote! {
2666 impl sql_forge::SqlForgeQuery<#final_type>
2667 for #query_ident
2668 {
2669 type Db = #db;
2670
2671 fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2672 where
2673 Self: Sized + 'e,
2674 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2675 #db: 'e,
2676 {
2677 #query_ident::fetch_all(self, executor)
2678 }
2679
2680 fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2681 where
2682 Self: Sized + 'e,
2683 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2684 #db: 'e,
2685 {
2686 #query_ident::fetch_one(self, executor)
2687 }
2688
2689 fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2690 where
2691 Self: Sized + 'e,
2692 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2693 #db: 'e,
2694 {
2695 #query_ident::fetch_optional(self, executor)
2696 }
2697
2698 fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2699 where
2700 Self: Sized + 'e,
2701 E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2702 #db: 'e,
2703 {
2704 #query_ident::execute(self, executor)
2705 }
2706 }
2707 }
2708 };
2709
2710 generated_query_defs.push(quote! {
2711 struct #query_ident {
2712 inner: sqlx::QueryBuilder<#db>,
2713 }
2714
2715 impl #query_ident {
2716 #exec_methods
2717 }
2718
2719 #trait_impl
2720 });
2721
2722 generated_query_values.push(quote! {
2723 #( #runtime_param_bindings )*
2724 #( #flag_bindings )*
2725 let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2726 #( #runtime_steps )*
2727 let #query_value_ident = #query_ident { inner: __builder };
2728 });
2729
2730 if let Some(key) = result_key {
2731 let method_ident = format_ident!("{}", key);
2732 group_field_defs.push(quote! {
2733 #method_ident: #query_ident
2734 });
2735 group_field_tys.push(quote! { #query_ident });
2736 group_method_defs.push(quote! {
2737 pub fn #method_ident(self) -> #query_ident {
2738 self.#method_ident
2739 }
2740 });
2741
2742 let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2743 group_trait_impls.push(quote! {
2744 struct #key_ty_ident;
2745
2746 impl sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup {
2747 type Query = #query_ident;
2748
2749 fn get(self, _: #key_ty_ident) -> Self::Query {
2750 self.#method_ident
2751 }
2752 }
2753 });
2754 group_field_idents.push(method_ident);
2755 }
2756 }
2757
2758 let validator_tokens = quote! {
2760 let _sql_forge_validator = || {
2761 #( #validator_param_bindings )*
2762 #( #grouped_validator_invocations )*
2763 };
2764 };
2765
2766 if !is_grouped_result {
2767 let single_query_value_ident = format_ident!("__sql_forge_value_single");
2768 return quote! {
2769 {
2770 #validator_tokens
2771 #( #generated_query_defs )*
2772 #( #generated_query_values )*
2773 #single_query_value_ident
2774 }
2775 }
2776 .into();
2777 }
2778
2779 let group_field_inits: Vec<TokenStream2> = result_cases
2780 .iter()
2781 .filter_map(|(key, _, _)| key.as_ref())
2782 .map(|key| {
2783 let method_ident = format_ident!("{}", key);
2784 let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2785 quote! { #method_ident: #query_value_ident }
2786 })
2787 .collect();
2788
2789 quote! {
2790 {
2791 #validator_tokens
2792
2793 #( #generated_query_defs )*
2794 #( #generated_query_values )*
2795
2796 struct __SqlForgeQueryGroup {
2797 #( #group_field_defs, )*
2798 }
2799
2800 impl __SqlForgeQueryGroup {
2801 #( #group_method_defs )*
2802
2803 pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2804 ( #( self.#group_field_idents ),* )
2805 }
2806 }
2807
2808 impl sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup {
2809 type Db = #db;
2810 }
2811
2812 #( #group_trait_impls )*
2813
2814 __SqlForgeQueryGroup {
2815 #( #group_field_inits, )*
2816 }
2817 }
2818 }
2819 .into()
2820}
2821
2822#[proc_macro]
2836pub fn db_type(input: TokenStream) -> TokenStream {
2837 if !input.is_empty() {
2838 return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2839 .to_compile_error()
2840 .into();
2841 }
2842
2843 match resolve_db_from_env() {
2844 Ok(db) => quote! { #db }.into(),
2845 Err(msg) => syn::Error::new(Span::call_site(), msg)
2846 .to_compile_error()
2847 .into(),
2848 }
2849}
2850
2851#[proc_macro_attribute]
2866pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2867 let input: ItemStruct = match syn::parse(item) {
2868 Ok(v) => v,
2869 Err(err) => return err.to_compile_error().into(),
2870 };
2871
2872 let struct_name = &input.ident;
2873 let inner_type = match &input.fields {
2874 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2875 _ => {
2876 return syn::Error::new(
2877 input.span(),
2878 "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2879 )
2880 .to_compile_error()
2881 .into();
2882 }
2883 };
2884
2885 let attrs = input.attrs;
2886 let generics = &input.generics;
2887 let vis = &input.vis;
2888 let struct_token = input.struct_token;
2889 let semi_token = input.semi_token;
2890 let fields = &input.fields;
2891
2892 let expanded = quote! {
2893 #( #attrs )*
2894 #[derive(sqlx::Type)]
2895 #[sqlx(transparent)]
2896 #vis #struct_token #struct_name #generics #fields #semi_token
2897
2898 impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2899 fn sql_forge_validator_value(&self) -> #inner_type {
2900 self.0.clone()
2901 }
2902 }
2903 };
2904
2905 expanded.into()
2906}