1#![forbid(unsafe_code)]
2
3use std::ops::Deref;
4use std::path::PathBuf;
5
6use ariadne::{Color, Label, Report, ReportKind, Source};
7use once_cell::sync::Lazy;
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::{format_ident, quote, quote_spanned};
11use sql_type::schema::{parse_schemas, Schemas};
12use sql_type::{type_statement, Issue, SQLArguments, SQLDialect, SelectTypeColumn, TypeOptions};
13use syn::spanned::Spanned;
14use syn::{parse::Parse, punctuated::Punctuated, Expr, Ident, LitStr, Token};
15
16static SCHEMA_PATH: Lazy<PathBuf> = Lazy::new(|| {
17 let mut schema_path: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
18 .expect("`CARGO_schema_path` must be set")
19 .into();
20
21 schema_path.push("sqlx-type-schema.sql");
22
23 if !schema_path.exists() {
24 use serde::Deserialize;
25 use std::process::Command;
26
27 let cargo = std::env::var("CARGO").expect("`CARGO` must be set");
28 schema_path.pop();
29
30 let output = Command::new(cargo)
31 .args(["metadata", "--format-version=1"])
32 .current_dir(&schema_path)
33 .env_remove("__CARGO_FIX_PLZ")
34 .output()
35 .expect("Could not fetch metadata");
36
37 #[derive(Deserialize)]
38 struct CargoMetadata {
39 workspace_root: PathBuf,
40 }
41
42 let metadata: CargoMetadata =
43 serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
44
45 schema_path = metadata.workspace_root;
46 schema_path.push("sqlx-type-schema.sql");
47 }
48 if !schema_path.exists() {
49 panic!("Unable to locate sqlx-type-schema.sql");
50 }
51 schema_path
52});
53
54static SCHEMA_SRC: Lazy<String> =
57 Lazy::new(|| match std::fs::read_to_string(SCHEMA_PATH.as_path()) {
58 Ok(v) => v,
59 Err(e) => panic!(
60 "Unable to read schema from {:?}: {}",
61 SCHEMA_PATH.as_path(),
62 e
63 ),
64 });
65
66fn issue_to_report(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
67 let mut builder = Report::build(
68 match issue.level {
69 sql_type::Level::Warning => ReportKind::Warning,
70 sql_type::Level::Error => ReportKind::Error,
71 },
72 issue.span.clone(),
73 )
74 .with_config(ariadne::Config::default().with_color(false))
75 .with_label(
76 Label::new(issue.span)
77 .with_order(-1)
78 .with_priority(-1)
79 .with_message(issue.message),
80 );
81 for frag in issue.fragments {
82 builder = builder.with_label(Label::new(frag.span).with_message(frag.message));
83 }
84 builder.finish()
85}
86
87fn issue_to_report_color(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
88 let mut builder = Report::build(
89 match issue.level {
90 sql_type::Level::Warning => ReportKind::Warning,
91 sql_type::Level::Error => ReportKind::Error,
92 },
93 issue.span.clone(),
94 )
95 .with_label(
96 Label::new(issue.span)
97 .with_color(match issue.level {
98 sql_type::Level::Warning => Color::Yellow,
99 sql_type::Level::Error => Color::Red,
100 })
101 .with_order(-1)
102 .with_priority(-1)
103 .with_message(issue.message),
104 );
105 for frag in issue.fragments {
106 builder = builder.with_label(
107 Label::new(frag.span)
108 .with_color(Color::Blue)
109 .with_message(frag.message),
110 );
111 }
112 builder.finish()
113}
114
115struct NamedSource<'a>(&'a str, Source<&'a str>);
116
117impl<'a> ariadne::Cache<()> for &NamedSource<'a> {
118 type Storage = &'a str;
119
120 fn display<'b>(&self, _: &'b ()) -> Option<impl std::fmt::Display + 'b> {
121 Some(self.0.to_string())
122 }
123
124 fn fetch(&mut self, _: &()) -> Result<&Source<Self::Storage>, impl std::fmt::Debug> {
125 Ok::<_, ()>(&self.1)
126 }
127}
128
129static SCHEMAS: Lazy<(Schemas, SQLDialect)> = Lazy::new(|| {
130 let schema_src = SCHEMA_SRC.as_str();
131 let dialect = if let Some(first_line) = schema_src.lines().next() {
132 if first_line.contains("sql-product: postgres") {
133 SQLDialect::PostgreSQL
134 } else if first_line.contains("sql-product: sqlite") {
135 SQLDialect::Sqlite
136 } else {
137 SQLDialect::MariaDB
138 }
139 } else {
140 SQLDialect::MariaDB
141 };
142
143 let options = TypeOptions::new().dialect(dialect.clone());
144 let mut issues = sql_type::Issues::new(schema_src);
145 let schemas = parse_schemas(schema_src, &mut issues, &options);
146 if !issues.is_ok() {
147 let source = NamedSource("sqlx-type-schema.sql", Source::from(schema_src));
148 let mut err = false;
149 for issue in issues.into_vec() {
150 if issue.level == sql_type::Level::Error {
151 err = true;
152 }
153 let r = issue_to_report_color(issue);
154 r.eprint(&source).unwrap();
155 }
156 if err {
157 panic!("Errors processing sqlx-type-schema.sql");
158 }
159 }
160 (schemas, dialect)
161});
162
163fn quote_args(
164 errors: &mut Vec<proc_macro2::TokenStream>,
165 query: &str,
166 last_span: Span,
167 args: &[Expr],
168 arguments: &[(sql_type::ArgumentKey<'_>, sql_type::FullType)],
169 dialect: &SQLDialect,
170) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
171 let cls = match dialect {
172 SQLDialect::MariaDB => quote!(sqlx::mysql::MySql),
173 SQLDialect::Sqlite => quote!(sqlx::sqlite::Sqlite),
174 SQLDialect::PostgreSQL => quote!(sqlx::postgres::Postgres),
175 };
176
177 let mut at = Vec::new();
178 let inv = sql_type::FullType::invalid();
179 for (k, v) in arguments {
180 match k {
181 sql_type::ArgumentKey::Index(i) => {
182 while at.len() <= *i {
183 at.push(&inv);
184 }
185 at[*i] = v;
186 }
187 sql_type::ArgumentKey::Identifier(_) => {
188 errors.push(
189 syn::Error::new(last_span.span(), "Named arguments not supported")
190 .to_compile_error(),
191 );
192 }
193 }
194 }
195
196 if at.len() > args.len() {
197 errors.push(
198 syn::Error::new(
199 last_span,
200 format!("Expected {} additional arguments", at.len() - args.len()),
201 )
202 .to_compile_error(),
203 );
204 }
205
206 if let Some(args) = args.get(at.len()..) {
207 for arg in args {
208 errors.push(syn::Error::new(arg.span(), "unexpected argument").to_compile_error());
209 }
210 }
211
212 let arg_names = (0..args.len())
213 .map(|i| format_ident!("arg{}", i))
214 .collect::<Vec<_>>();
215
216 let mut arg_bindings = Vec::new();
217 let mut arg_add = Vec::new();
218
219 let mut list_lengths = Vec::new();
220
221 for ((qa, ta), name) in args.iter().zip(at).zip(&arg_names) {
222 let mut t = match ta.t {
223 sql_type::Type::U8 => quote! {u8},
224 sql_type::Type::I8 => quote! {i8},
225 sql_type::Type::U16 => quote! {u16},
226 sql_type::Type::I16 => quote! {i16},
227 sql_type::Type::U32 => quote! {u32},
228 sql_type::Type::I32 => quote! {i32},
229 sql_type::Type::U64 => quote! {u64},
230 sql_type::Type::I64 => quote! {i64},
231 sql_type::Type::Base(sql_type::BaseType::Any) => quote! {sqlx_type::Any},
232 sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
233 sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {&[u8]},
234 sql_type::Type::Base(sql_type::BaseType::Date) => quote! {sqlx_type::Date},
235 sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {sqlx_type::DateTime},
236 sql_type::Type::Base(sql_type::BaseType::Float) => quote! {sqlx_type::Float},
237 sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {sqlx_type::Integer},
238 sql_type::Type::Base(sql_type::BaseType::String) => quote! {&str},
239 sql_type::Type::Base(sql_type::BaseType::Time) => quote! {sqlx_type::Time},
240 sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("time_interval"),
241 sql_type::Type::Base(sql_type::BaseType::TimeStamp) => quote! {sqlx_type::Timestamp},
242 sql_type::Type::Null => todo!("null"),
243 sql_type::Type::Invalid => quote! {std::convert::Infallible},
244 sql_type::Type::Enum(_) => quote! {&str},
245 sql_type::Type::Set(_) => quote! {&str},
246 sql_type::Type::Args(_, _) => todo!("args"),
247 sql_type::Type::F32 => quote! {f32},
248 sql_type::Type::F64 => quote! {f64},
249 sql_type::Type::JSON => quote! {sqlx_type::Any},
250 };
251 if !ta.not_null {
252 t = quote! {Option<#t>}
253 }
254 let span = qa.span();
255 if ta.list_hack {
256 list_lengths.push(quote!(#name.len()));
257 arg_bindings.push(quote_spanned! {span=>
258 let #name = &(#qa);
259 args_count += #name.len();
260 for v in #name.iter() {
261 size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(v);
262 }
263 if false {
264 sqlx_type::check_arg_list_hack::<#t, _>(#name);
265 ::std::panic!();
266 }
267 });
268 arg_add.push(quote!(
269 for v in #name.iter() {
270 e = e.and_then(|()| query_args.add(v));
271 }
272 ));
273 } else {
274 arg_bindings.push(quote_spanned! {span=>
275 let #name = &(#qa);
276 args_count += 1;
277 size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(#name);
278 if false {
279 sqlx_type::check_arg::<#t, _>(#name);
280 ::std::panic!();
281 }
282 });
283 arg_add.push(quote!(e = e.and_then(|()| query_args.add(#name));));
284 }
285 }
286
287 let query = if list_lengths.is_empty() {
288 quote!(#query)
289 } else {
290 quote!(
291 &sqlx_type::convert_list_query(#query, &[#(#list_lengths),*])
292 )
293 };
294
295 (
296 quote! {
297 let mut size_hints = 0;
298 let mut args_count = 0;
299 #(#arg_bindings)*
300
301 let mut query_args = <#cls as ::sqlx::database::Database>::Arguments::default();
302 query_args.reserve(args_count, size_hints);
303 let mut e = Ok(());
304 #(#arg_add)*
305 let query_args = e.and_then(|()| Ok(query_args));
306 },
307 query,
308 )
309}
310
311fn issues_to_errors(issues: Vec<Issue>, source: &str, span: Span) -> Vec<proc_macro2::TokenStream> {
312 if !issues.is_empty() {
313 let source = NamedSource("", Source::from(source));
314 let mut err = false;
315 let mut out = Vec::new();
316 for issue in issues {
317 if issue.level == sql_type::Level::Error {
318 err = true;
319 }
320 let r = issue_to_report(issue);
321 r.write(&source, &mut out).unwrap();
322 }
323 if err {
324 return vec![syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error()];
325 }
326 }
327 Vec::new()
328}
329
330fn construct_row(
331 columns: &[SelectTypeColumn],
332) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
333 let mut row_members = Vec::new();
334 let mut row_construct = Vec::new();
335 for (i, c) in columns.iter().enumerate() {
336 let mut t = match c.type_.t {
337 sql_type::Type::U8 => quote! {u8},
338 sql_type::Type::I8 => quote! {i8},
339 sql_type::Type::U16 => quote! {u16},
340 sql_type::Type::I16 => quote! {i16},
341 sql_type::Type::U32 => quote! {u32},
342 sql_type::Type::I32 => quote! {i32},
343 sql_type::Type::U64 => quote! {u64},
344 sql_type::Type::I64 => quote! {i64},
345 sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
346 sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
347 sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
348 sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
349 sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
350 sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
351 sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
352 sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
353 sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
354 sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("from_time_interval"),
355 sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
356 quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
357 }
358 sql_type::Type::Null => todo!("from_null"),
359 sql_type::Type::Invalid => quote! {i64},
360 sql_type::Type::Enum(_) => quote! {String},
361 sql_type::Type::Set(_) => quote! {String},
362 sql_type::Type::Args(_, _) => todo!("from_args"),
363 sql_type::Type::F32 => quote! {f32},
364 sql_type::Type::F64 => quote! {f64},
365 sql_type::Type::JSON => quote! {String},
366 };
367 let name = match &c.name {
368 Some(v) => v,
369 None => continue,
370 };
371
372 let ident = String::from("r#") + name.value;
373 let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
374 ident
375 } else {
376 continue;
379 };
380
381 if !c.type_.not_null {
382 t = quote! {Option<#t>};
383 }
384 row_members.push(quote! {
385 #ident : #t
386 });
387 row_construct.push(quote! {
388 #ident: sqlx::Row::get(&row, #i)
389 });
390 }
391 (row_members, row_construct)
392}
393
394struct Query {
395 query: String,
396 query_span: Span,
397 args: Vec<Expr>,
398 last_span: Span,
399}
400
401impl Parse for Query {
402 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
403 let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
404 let query: String = query_.iter().map(LitStr::value).collect();
405 let query_span = query_.span();
406 let mut last_span = query_span;
407 let mut args = Vec::new();
408 while !input.is_empty() {
409 let _ = input.parse::<syn::token::Comma>()?;
410 if input.is_empty() {
411 break;
412 }
413 let arg = input.parse::<Expr>()?;
414 last_span = arg.span();
415 args.push(arg);
416 }
417 Ok(Self {
418 query,
419 query_span,
420 args,
421 last_span,
422 })
423 }
424}
425
426#[proc_macro]
430pub fn query(input: TokenStream) -> TokenStream {
431 let query = syn::parse_macro_input!(input as Query);
432 let (schemas, dialect) = SCHEMAS.deref();
433 let options = TypeOptions::new()
434 .dialect(dialect.clone())
435 .arguments(match &dialect {
436 SQLDialect::MariaDB => SQLArguments::QuestionMark,
437 SQLDialect::Sqlite => SQLArguments::QuestionMark,
438 SQLDialect::PostgreSQL => SQLArguments::Dollar,
439 })
440 .list_hack(true);
441 let mut issues = sql_type::Issues::new(&query.query);
442 let stmt = type_statement(schemas, &query.query, &mut issues, &options);
443 let sp = SCHEMA_PATH.as_path().to_str().unwrap();
444
445 let mut errors = issues_to_errors(issues.into_vec(), &query.query, query.query_span);
446 match &stmt {
447 sql_type::StatementType::Select { columns, arguments } => {
448 let (args_tokens, q) = quote_args(
449 &mut errors,
450 &query.query,
451 query.last_span,
452 &query.args,
453 arguments,
454 dialect,
455 );
456 let (row_members, row_construct) = construct_row(columns);
457 let s = quote! { {
458 use ::sqlx::Arguments as _;
459 let _ = std::include_bytes!(#sp);
460 #(#errors; )*
461 #args_tokens;
462
463 struct Row {
464 #(#row_members),*
465 };
466 sqlx::__query_with_result(#q, query_args).map(|row|
467 Row{
468 #(#row_construct),*
469 }
470 )
471 }};
472 s.into()
473 }
474 sql_type::StatementType::Delete {
475 arguments,
476 returning,
477 } => {
478 let (args_tokens, q) = quote_args(
479 &mut errors,
480 &query.query,
481 query.last_span,
482 &query.args,
483 arguments,
484 dialect,
485 );
486 let s = match returning.as_ref() {
487 Some(returning) => {
488 let (row_members, row_construct) = construct_row(returning);
489 quote! { {
490 use ::sqlx::Arguments as _;
491 let _ = std::include_bytes!(#sp);
492 #(#errors; )*
493 #args_tokens
494
495 struct Row {
496 #(#row_members),*
497 };
498 sqlx::__query_with_result(#q, query_args).map(|row|
499 Row{
500 #(#row_construct),*
501 }
502 )
503 }}
504 }
505 None => quote! { {
506 use ::sqlx::Arguments as _;
507 #(#errors; )*
508 #args_tokens
509 sqlx::__query_with_result(#q, query_args)
510 }
511 },
512 };
513 s.into()
514 }
515 sql_type::StatementType::Insert {
516 arguments,
517 returning,
518 ..
519 } => {
520 let (args_tokens, q) = quote_args(
521 &mut errors,
522 &query.query,
523 query.last_span,
524 &query.args,
525 arguments,
526 dialect,
527 );
528 let s = match returning.as_ref() {
529 Some(returning) => {
530 let (row_members, row_construct) = construct_row(returning);
531 quote! { {
532 use ::sqlx::Arguments as _;
533 let _ = std::include_bytes!(#sp);
534 #(#errors; )*
535 #args_tokens
536
537 struct Row {
538 #(#row_members),*
539 };
540 sqlx::__query_with_result(#q, query_args).map(|row|
541 Row{
542 #(#row_construct),*
543 }
544 )
545 }}
546 }
547 None => quote! { {
548 use ::sqlx::Arguments as _;
549 #(#errors; )*
550 #args_tokens
551 sqlx::__query_with_result(#q, query_args)
552 }
553 },
554 };
555 s.into()
556 }
557 sql_type::StatementType::Update {
558 arguments,
559 returning,
560 } => {
561 let (args_tokens, q) = quote_args(
562 &mut errors,
563 &query.query,
564 query.last_span,
565 &query.args,
566 arguments,
567 dialect,
568 );
569
570 let s = match returning.as_ref() {
571 Some(returning) => {
572 let (row_members, row_construct) = construct_row(returning);
573 quote! { {
574 use ::sqlx::Arguments as _;
575 let _ = std::include_bytes!(#sp);
576 #(#errors; )*
577 #args_tokens
578
579 struct Row {
580 #(#row_members),*
581 };
582 sqlx::__query_with_result(#q, query_args).map(|row|
583 Row{
584 #(#row_construct),*
585 }
586 )
587 }}
588 }
589 None => quote! { {
590 use ::sqlx::Arguments as _;
591 #(#errors; )*
592 #args_tokens
593 sqlx::__query_with_result(#q, query_args)
594 }
595 },
596 };
597 s.into()
598 }
599 sql_type::StatementType::Replace {
600 arguments,
601 returning,
602 } => {
603 let (args_tokens, q) = quote_args(
604 &mut errors,
605 &query.query,
606 query.last_span,
607 &query.args,
608 arguments,
609 dialect,
610 );
611 let s = match returning.as_ref() {
612 Some(returning) => {
613 let (row_members, row_construct) = construct_row(returning);
614 quote! { {
615 use ::sqlx::Arguments as _;
616 let _ = std::include_bytes!(#sp);
617 #(#errors; )*
618 #args_tokens
619
620 struct Row {
621 #(#row_members),*
622 };
623 sqlx::__query_with_result(#q, query_args).map(|row|
624 Row{
625 #(#row_construct),*
626 }
627 )
628 }}
629 }
630 None => quote! { {
631 use ::sqlx::Arguments as _;
632 #(#errors; )*
633 #args_tokens
634 sqlx::__query_with_result(#q, query_args)
635 }
636 },
637 };
638 s.into()
639 }
640 sql_type::StatementType::Invalid => {
641 let s = quote! { {
642 #(#errors; )*;
643 todo!("Invalid")
644 }};
645 s.into()
646 }
647 }
648}
649
650fn construct_row2(columns: &[SelectTypeColumn]) -> Vec<proc_macro2::TokenStream> {
651 let mut row_construct = Vec::new();
652 for (i, c) in columns.iter().enumerate() {
653 let mut t = match c.type_.t {
654 sql_type::Type::U8 => quote! {u8},
655 sql_type::Type::I8 => quote! {i8},
656 sql_type::Type::U16 => quote! {u16},
657 sql_type::Type::I16 => quote! {i16},
658 sql_type::Type::U32 => quote! {u32},
659 sql_type::Type::I32 => quote! {i32},
660 sql_type::Type::U64 => quote! {u64},
661 sql_type::Type::I64 => quote! {i64},
662 sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
663 sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
664 sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
665 sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
666 sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
667 sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
668 sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
669 sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
670 sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
671 sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("from_time_interval"),
672 sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
673 quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
674 }
675 sql_type::Type::Null => todo!("from_null"),
676 sql_type::Type::Invalid => quote! {i64},
677 sql_type::Type::Enum(_) => quote! {String},
678 sql_type::Type::Set(_) => quote! {String},
679 sql_type::Type::Args(_, _) => todo!("from_args"),
680 sql_type::Type::F32 => quote! {f32},
681 sql_type::Type::F64 => quote! {f64},
682 sql_type::Type::JSON => quote! {String},
683 };
684 let name = match &c.name {
685 Some(v) => v,
686 None => continue,
687 };
688
689 let ident = String::from("r#") + name.value;
690 let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
691 ident
692 } else {
693 continue;
696 };
697
698 if !c.type_.not_null {
699 t = quote! {Option<#t>};
700 }
701 row_construct.push(quote! {
702 #ident: sqlx_type::arg_out::<#t, _, #i>(sqlx::Row::get(&row, #i))
703 });
704 }
705 row_construct
706}
707
708struct QueryAs {
709 as_: Ident,
710 query: String,
711 query_span: Span,
712 args: Vec<Expr>,
713 last_span: Span,
714}
715
716impl Parse for QueryAs {
717 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
718 let as_ = input.parse::<Ident>()?;
719 let _ = input.parse::<syn::token::Comma>()?;
720
721 let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
722 let query: String = query_.iter().map(LitStr::value).collect();
723 let query_span = query_.span();
724
725 let mut last_span = query_span;
726 let mut args = Vec::new();
727 while !input.is_empty() {
728 let _ = input.parse::<syn::token::Comma>()?;
729 if input.is_empty() {
730 break;
731 }
732 let arg = input.parse::<Expr>()?;
733 last_span = arg.span();
734 args.push(arg);
735 }
736 Ok(Self {
737 as_,
738 query,
739 query_span,
740 args,
741 last_span,
742 })
743 }
744}
745
746#[proc_macro]
750pub fn query_as(input: TokenStream) -> TokenStream {
751 let query_as = syn::parse_macro_input!(input as QueryAs);
752 let (schemas, dialect) = SCHEMAS.deref();
753 let options = TypeOptions::new()
754 .dialect(dialect.clone())
755 .arguments(match &dialect {
756 SQLDialect::MariaDB => SQLArguments::QuestionMark,
757 SQLDialect::Sqlite => SQLArguments::QuestionMark,
758 SQLDialect::PostgreSQL => SQLArguments::Dollar,
759 })
760 .list_hack(true);
761 let mut issues = sql_type::Issues::new(&query_as.query);
762 let stmt = type_statement(schemas, &query_as.query, &mut issues, &options);
763
764 let mut errors = issues_to_errors(issues.into_vec(), &query_as.query, query_as.query_span);
765 match &stmt {
766 sql_type::StatementType::Select { columns, arguments } => {
767 let (args_tokens, q) = quote_args(
768 &mut errors,
769 &query_as.query,
770 query_as.last_span,
771 &query_as.args,
772 arguments,
773 dialect,
774 );
775
776 let row_construct = construct_row2(columns);
777 let row = query_as.as_;
778 let s = quote! { {
779 use ::sqlx::Arguments as _;
780 #(#errors; )*
781 #args_tokens
782 sqlx::__query_with_result(#q, query_args).map(|row|
783 #row{
784 #(#row_construct),*
785 }
786 )
787 }};
788 s.into()
790 }
791 sql_type::StatementType::Delete { .. } => {
792 errors.push(
793 syn::Error::new(query_as.query_span, "DELETE not support in query_as")
794 .to_compile_error(),
795 );
796 quote! { {
797 #(#errors; )*
798 todo!("delete")
799 }}
800 .into()
801 }
802 sql_type::StatementType::Insert {
803 returning: None, ..
804 } => {
805 errors.push(
806 syn::Error::new(
807 query_as.query_span,
808 "INSERT without RETURNING not support in query_as",
809 )
810 .to_compile_error(),
811 );
812 quote! { {
813 #(#errors; )*
814 todo!("insert")
815 }}
816 .into()
817 }
818 sql_type::StatementType::Insert {
819 arguments,
820 returning: Some(returning),
821 ..
822 } => {
823 let (args_tokens, q) = quote_args(
824 &mut errors,
825 &query_as.query,
826 query_as.last_span,
827 &query_as.args,
828 arguments,
829 dialect,
830 );
831
832 let row_construct = construct_row2(returning);
833 let row = query_as.as_;
834 let s = quote! { {
835 use ::sqlx::Arguments as _;
836 #(#errors; )*
837 #args_tokens
838 sqlx::__query_with_result(#q, query_args).map(|row|
839 #row{
840 #(#row_construct),*
841 }
842 )
843 }};
844 s.into()
845 }
846 sql_type::StatementType::Update { .. } => {
847 errors.push(
848 syn::Error::new(query_as.query_span, "UPDATE not support in query_as")
849 .to_compile_error(),
850 );
851 quote! { {
852 #(#errors; )*
853 todo!("update")
854 }}
855 .into()
856 }
857 sql_type::StatementType::Replace {
858 returning: None, ..
859 } => {
860 errors.push(
861 syn::Error::new(
862 query_as.query_span,
863 "REPLACE without RETURNING not support in query_as",
864 )
865 .to_compile_error(),
866 );
867 quote! { {
868 #(#errors; )*
869 todo!("replace")
870 }}
871 .into()
872 }
873 sql_type::StatementType::Replace {
874 arguments,
875 returning: Some(returning),
876 ..
877 } => {
878 let (args_tokens, q) = quote_args(
879 &mut errors,
880 &query_as.query,
881 query_as.last_span,
882 &query_as.args,
883 arguments,
884 dialect,
885 );
886
887 let row_construct = construct_row2(returning);
888 let row = query_as.as_;
889 let s = quote! { {
890 use ::sqlx::Arguments as _;
891 #(#errors; )*
892 #args_tokens
893 sqlx::__query_with_result(#q, query_args).map(|row|
894 #row{
895 #(#row_construct),*
896 }
897 )
898 }};
899 s.into()
900 }
901 sql_type::StatementType::Invalid => quote! { {
902 #(#errors; )*;
903 todo!("invalid")
904 }}
905 .into(),
906 }
907}