sqlx_type_macro/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::ops::Deref;
4use std::path::PathBuf;
5
6use ariadne::{Color, Label, Report, ReportKind, Source};
7use once_cell::sync::Lazy;
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::{format_ident, quote, quote_spanned};
11use sql_type::schema::{parse_schemas, Schemas};
12use sql_type::{type_statement, Issue, SQLArguments, SQLDialect, SelectTypeColumn, TypeOptions};
13use syn::spanned::Spanned;
14use syn::{parse::Parse, punctuated::Punctuated, Expr, Ident, LitStr, Token};
15
16static SCHEMA_PATH: Lazy<PathBuf> = Lazy::new(|| {
17    let mut schema_path: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
18        .expect("`CARGO_schema_path` must be set")
19        .into();
20
21    schema_path.push("sqlx-type-schema.sql");
22
23    if !schema_path.exists() {
24        use serde::Deserialize;
25        use std::process::Command;
26
27        let cargo = std::env::var("CARGO").expect("`CARGO` must be set");
28        schema_path.pop();
29
30        let output = Command::new(cargo)
31            .args(["metadata", "--format-version=1"])
32            .current_dir(&schema_path)
33            .env_remove("__CARGO_FIX_PLZ")
34            .output()
35            .expect("Could not fetch metadata");
36
37        #[derive(Deserialize)]
38        struct CargoMetadata {
39            workspace_root: PathBuf,
40        }
41
42        let metadata: CargoMetadata =
43            serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
44
45        schema_path = metadata.workspace_root;
46        schema_path.push("sqlx-type-schema.sql");
47    }
48    if !schema_path.exists() {
49        panic!("Unable to locate sqlx-type-schema.sql");
50    }
51    schema_path
52});
53
// Note on `SCHEMA_PATH` above: when building inside a workspace it falls back to the
// `workspace_root` reported by `cargo metadata`, since `CARGO_MANIFEST_DIR` won't
// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
56static SCHEMA_SRC: Lazy<String> =
57    Lazy::new(|| match std::fs::read_to_string(SCHEMA_PATH.as_path()) {
58        Ok(v) => v,
59        Err(e) => panic!(
60            "Unable to read schema from {:?}: {}",
61            SCHEMA_PATH.as_path(),
62            e
63        ),
64    });
65
66fn issue_to_report(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
67    let mut builder = Report::build(
68        match issue.level {
69            sql_type::Level::Warning => ReportKind::Warning,
70            sql_type::Level::Error => ReportKind::Error,
71        },
72        issue.span.clone(),
73    )
74    .with_config(ariadne::Config::default().with_color(false))
75    .with_label(
76        Label::new(issue.span)
77            .with_order(-1)
78            .with_priority(-1)
79            .with_message(issue.message),
80    );
81    for frag in issue.fragments {
82        builder = builder.with_label(Label::new(frag.span).with_message(frag.message));
83    }
84    builder.finish()
85}
86
87fn issue_to_report_color(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
88    let mut builder = Report::build(
89        match issue.level {
90            sql_type::Level::Warning => ReportKind::Warning,
91            sql_type::Level::Error => ReportKind::Error,
92        },
93        issue.span.clone(),
94    )
95    .with_label(
96        Label::new(issue.span)
97            .with_color(match issue.level {
98                sql_type::Level::Warning => Color::Yellow,
99                sql_type::Level::Error => Color::Red,
100            })
101            .with_order(-1)
102            .with_priority(-1)
103            .with_message(issue.message),
104    );
105    for frag in issue.fragments {
106        builder = builder.with_label(
107            Label::new(frag.span)
108                .with_color(Color::Blue)
109                .with_message(frag.message),
110        );
111    }
112    builder.finish()
113}
114
115struct NamedSource<'a>(&'a str, Source<&'a str>);
116
117impl<'a> ariadne::Cache<()> for &NamedSource<'a> {
118    type Storage = &'a str;
119
120    fn display<'b>(&self, _: &'b ()) -> Option<Box<dyn std::fmt::Display + 'b>> {
121        Some(Box::new(self.0.to_string()))
122    }
123
124    fn fetch(&mut self, _: &()) -> Result<&Source<Self::Storage>, Box<dyn std::fmt::Debug + '_>> {
125        Ok(&self.1)
126    }
127}
128
129static SCHEMAS: Lazy<(Schemas, SQLDialect)> = Lazy::new(|| {
130    let schema_src = SCHEMA_SRC.as_str();
131    let dialect = if let Some(first_line) = schema_src.lines().next() {
132        if first_line.contains("sql-product: postgres") {
133            SQLDialect::PostgreSQL
134        } else if first_line.contains("sql-product: sqlite") {
135            SQLDialect::Sqlite
136        } else {
137            SQLDialect::MariaDB
138        }
139    } else {
140        SQLDialect::MariaDB
141    };
142
143    let options = TypeOptions::new().dialect(dialect.clone());
144    let mut issues = sql_type::Issues::new(schema_src);
145    let schemas = parse_schemas(schema_src, &mut issues, &options);
146    if !issues.is_ok() {
147        let source = NamedSource("sqlx-type-schema.sql", Source::from(schema_src));
148        let mut err = false;
149        for issue in issues.into_vec() {
150            if issue.level == sql_type::Level::Error {
151                err = true;
152            }
153            let r = issue_to_report_color(issue);
154            r.eprint(&source).unwrap();
155        }
156        if err {
157            panic!("Errors processing sqlx-type-schema.sql");
158        }
159    }
160    (schemas, dialect)
161});
162
163fn quote_args(
164    errors: &mut Vec<proc_macro2::TokenStream>,
165    query: &str,
166    last_span: Span,
167    args: &[Expr],
168    arguments: &[(sql_type::ArgumentKey<'_>, sql_type::FullType)],
169    dialect: &SQLDialect,
170) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
171    let cls = match dialect {
172        SQLDialect::MariaDB => quote!(sqlx::mysql::MySql),
173        SQLDialect::Sqlite => quote!(sqlx::sqlite::Sqlite),
174        SQLDialect::PostgreSQL => quote!(sqlx::postgres::Postgres),
175    };
176
177    let mut at = Vec::new();
178    let inv = sql_type::FullType::invalid();
179    for (k, v) in arguments {
180        match k {
181            sql_type::ArgumentKey::Index(i) => {
182                while at.len() <= *i {
183                    at.push(&inv);
184                }
185                at[*i] = v;
186            }
187            sql_type::ArgumentKey::Identifier(_) => {
188                errors.push(
189                    syn::Error::new(last_span.span(), "Named arguments not supported")
190                        .to_compile_error(),
191                );
192            }
193        }
194    }
195
196    if at.len() > args.len() {
197        errors.push(
198            syn::Error::new(
199                last_span,
200                format!("Expected {} additional arguments", at.len() - args.len()),
201            )
202            .to_compile_error(),
203        );
204    }
205
206    if let Some(args) = args.get(at.len()..) {
207        for arg in args {
208            errors.push(syn::Error::new(arg.span(), "unexpected argument").to_compile_error());
209        }
210    }
211
212    let arg_names = (0..args.len())
213        .map(|i| format_ident!("arg{}", i))
214        .collect::<Vec<_>>();
215
216    let mut arg_bindings = Vec::new();
217    let mut arg_add = Vec::new();
218
219    let mut list_lengths = Vec::new();
220
221    for ((qa, ta), name) in args.iter().zip(at).zip(&arg_names) {
222        let mut t = match ta.t {
223            sql_type::Type::U8 => quote! {u8},
224            sql_type::Type::I8 => quote! {i8},
225            sql_type::Type::U16 => quote! {u16},
226            sql_type::Type::I16 => quote! {i16},
227            sql_type::Type::U32 => quote! {u32},
228            sql_type::Type::I32 => quote! {i32},
229            sql_type::Type::U64 => quote! {u64},
230            sql_type::Type::I64 => quote! {i64},
231            sql_type::Type::Base(sql_type::BaseType::Any) => quote! {sqlx_type::Any},
232            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
233            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {&[u8]},
234            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {sqlx_type::Date},
235            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {sqlx_type::DateTime},
236            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {sqlx_type::Float},
237            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {sqlx_type::Integer},
238            sql_type::Type::Base(sql_type::BaseType::String) => quote! {&str},
239            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("time"),
240            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => quote! {sqlx_type::Timestamp},
241            sql_type::Type::Null => todo!("null"),
242            sql_type::Type::Invalid => quote! {std::convert::Infallible},
243            sql_type::Type::Enum(_) => quote! {&str},
244            sql_type::Type::Set(_) => quote! {&str},
245            sql_type::Type::Args(_, _) => todo!("args"),
246            sql_type::Type::F32 => quote! {f32},
247            sql_type::Type::F64 => quote! {f64},
248            sql_type::Type::JSON => quote! {sqlx_type::Any},
249        };
250        if !ta.not_null {
251            t = quote! {Option<#t>}
252        }
253        let span = qa.span();
254        if ta.list_hack {
255            list_lengths.push(quote!(#name.len()));
256            arg_bindings.push(quote_spanned! {span=>
257                let #name = &(#qa);
258                args_count += #name.len();
259                for v in #name.iter() {
260                    size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(v);
261                }
262                if false {
263                    sqlx_type::check_arg_list_hack::<#t, _>(#name);
264                    ::std::panic!();
265                }
266            });
267            arg_add.push(quote!(
268                for v in #name.iter() {
269                    e = e.and_then(|()| query_args.add(v));
270                }
271            ));
272        } else {
273            arg_bindings.push(quote_spanned! {span=>
274                let #name = &(#qa);
275                args_count += 1;
276                size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(#name);
277                if false {
278                    sqlx_type::check_arg::<#t, _>(#name);
279                    ::std::panic!();
280                }
281            });
282            arg_add.push(quote!(e = e.and_then(|()| query_args.add(#name));));
283        }
284    }
285
286    let query = if list_lengths.is_empty() {
287        quote!(#query)
288    } else {
289        quote!(
290            &sqlx_type::convert_list_query(#query, &[#(#list_lengths),*])
291        )
292    };
293
294    (
295        quote! {
296            let mut size_hints = 0;
297            let mut args_count = 0;
298            #(#arg_bindings)*
299
300            let mut query_args = <#cls as ::sqlx::database::Database>::Arguments::default();
301            query_args.reserve(args_count, size_hints);
302            let mut e = Ok(());
303            #(#arg_add)*
304            let query_args = e.and_then(|()| Ok(query_args));
305        },
306        query,
307    )
308}
309
310fn issues_to_errors(issues: Vec<Issue>, source: &str, span: Span) -> Vec<proc_macro2::TokenStream> {
311    if !issues.is_empty() {
312        let source = NamedSource("", Source::from(source));
313        let mut err = false;
314        let mut out = Vec::new();
315        for issue in issues {
316            if issue.level == sql_type::Level::Error {
317                err = true;
318            }
319            let r = issue_to_report(issue);
320            r.write(&source, &mut out).unwrap();
321        }
322        if err {
323            return vec![syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error()];
324        }
325    }
326    Vec::new()
327}
328
329fn construct_row(
330    columns: &[SelectTypeColumn],
331) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
332    let mut row_members = Vec::new();
333    let mut row_construct = Vec::new();
334    for (i, c) in columns.iter().enumerate() {
335        let mut t = match c.type_.t {
336            sql_type::Type::U8 => quote! {u8},
337            sql_type::Type::I8 => quote! {i8},
338            sql_type::Type::U16 => quote! {u16},
339            sql_type::Type::I16 => quote! {i16},
340            sql_type::Type::U32 => quote! {u32},
341            sql_type::Type::I32 => quote! {i32},
342            sql_type::Type::U64 => quote! {u64},
343            sql_type::Type::I64 => quote! {i64},
344            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
345            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
346            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
347            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
348            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
349            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
350            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
351            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
352            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
353            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
354                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
355            }
356            sql_type::Type::Null => todo!("from_null"),
357            sql_type::Type::Invalid => quote! {i64},
358            sql_type::Type::Enum(_) => quote! {String},
359            sql_type::Type::Set(_) => quote! {String},
360            sql_type::Type::Args(_, _) => todo!("from_args"),
361            sql_type::Type::F32 => quote! {f32},
362            sql_type::Type::F64 => quote! {f64},
363            sql_type::Type::JSON => quote! {String},
364        };
365        let name = match &c.name {
366            Some(v) => v,
367            None => continue,
368        };
369
370        let ident = String::from("r#") + name.value;
371        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
372            ident
373        } else {
374            // TODO error
375            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
376            continue;
377        };
378
379        if !c.type_.not_null {
380            t = quote! {Option<#t>};
381        }
382        row_members.push(quote! {
383            #ident : #t
384        });
385        row_construct.push(quote! {
386            #ident: sqlx::Row::get(&row, #i)
387        });
388    }
389    (row_members, row_construct)
390}
391
392struct Query {
393    query: String,
394    query_span: Span,
395    args: Vec<Expr>,
396    last_span: Span,
397}
398
399impl Parse for Query {
400    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
401        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
402        let query: String = query_.iter().map(LitStr::value).collect();
403        let query_span = query_.span();
404        let mut last_span = query_span;
405        let mut args = Vec::new();
406        while !input.is_empty() {
407            let _ = input.parse::<syn::token::Comma>()?;
408            if input.is_empty() {
409                break;
410            }
411            let arg = input.parse::<Expr>()?;
412            last_span = arg.span();
413            args.push(arg);
414        }
415        Ok(Self {
416            query,
417            query_span,
418            args,
419            last_span,
420        })
421    }
422}
423
424/// Statically checked SQL query, similarly to sqlx::query!.
425///
426/// This expands to an instance of query::Map that outputs an ad-hoc anonymous struct type.
427#[proc_macro]
428pub fn query(input: TokenStream) -> TokenStream {
429    let query = syn::parse_macro_input!(input as Query);
430    let (schemas, dialect) = SCHEMAS.deref();
431    let options = TypeOptions::new()
432        .dialect(dialect.clone())
433        .arguments(match &dialect {
434            SQLDialect::MariaDB => SQLArguments::QuestionMark,
435            SQLDialect::Sqlite => SQLArguments::QuestionMark,
436            SQLDialect::PostgreSQL => SQLArguments::Dollar,
437        })
438        .list_hack(true);
439    let mut issues = sql_type::Issues::new(&query.query);
440    let stmt = type_statement(schemas, &query.query, &mut issues, &options);
441    let sp = SCHEMA_PATH.as_path().to_str().unwrap();
442
443    let mut errors = issues_to_errors(issues.into_vec(), &query.query, query.query_span);
444    match &stmt {
445        sql_type::StatementType::Select { columns, arguments } => {
446            let (args_tokens, q) = quote_args(
447                &mut errors,
448                &query.query,
449                query.last_span,
450                &query.args,
451                arguments,
452                dialect,
453            );
454            let (row_members, row_construct) = construct_row(columns);
455            let s = quote! { {
456                use ::sqlx::Arguments as _;
457                let _ = std::include_bytes!(#sp);
458                #(#errors; )*
459                #args_tokens;
460
461                struct Row {
462                    #(#row_members),*
463                };
464                sqlx::__query_with_result(#q, query_args).map(|row|
465                    Row{
466                        #(#row_construct),*
467                    }
468                )
469            }};
470            s.into()
471        }
472        sql_type::StatementType::Delete {
473            arguments,
474            returning,
475        } => {
476            let (args_tokens, q) = quote_args(
477                &mut errors,
478                &query.query,
479                query.last_span,
480                &query.args,
481                arguments,
482                dialect,
483            );
484            let s = match returning.as_ref() {
485                Some(returning) => {
486                    let (row_members, row_construct) = construct_row(returning);
487                    quote! { {
488                        use ::sqlx::Arguments as _;
489                        let _ = std::include_bytes!(#sp);
490                        #(#errors; )*
491                        #args_tokens
492
493                        struct Row {
494                            #(#row_members),*
495                        };
496                        sqlx::__query_with_result(#q, query_args).map(|row|
497                            Row{
498                                #(#row_construct),*
499                            }
500                        )
501                    }}
502                }
503                None => quote! { {
504                    use ::sqlx::Arguments as _;
505                    #(#errors; )*
506                    #args_tokens
507                    sqlx::__query_with_result(#q, query_args)
508                }
509                },
510            };
511            s.into()
512        }
513        sql_type::StatementType::Insert {
514            arguments,
515            returning,
516            ..
517        } => {
518            let (args_tokens, q) = quote_args(
519                &mut errors,
520                &query.query,
521                query.last_span,
522                &query.args,
523                arguments,
524                dialect,
525            );
526            let s = match returning.as_ref() {
527                Some(returning) => {
528                    let (row_members, row_construct) = construct_row(returning);
529                    quote! { {
530                        use ::sqlx::Arguments as _;
531                        let _ = std::include_bytes!(#sp);
532                        #(#errors; )*
533                        #args_tokens
534
535                        struct Row {
536                            #(#row_members),*
537                        };
538                        sqlx::__query_with_result(#q, query_args).map(|row|
539                            Row{
540                                #(#row_construct),*
541                            }
542                        )
543                    }}
544                }
545                None => quote! { {
546                    use ::sqlx::Arguments as _;
547                    #(#errors; )*
548                    #args_tokens
549                    sqlx::__query_with_result(#q, query_args)
550                }
551                },
552            };
553            s.into()
554        }
555        sql_type::StatementType::Update { arguments } => {
556            let (args_tokens, q) = quote_args(
557                &mut errors,
558                &query.query,
559                query.last_span,
560                &query.args,
561                arguments,
562                dialect,
563            );
564            let s = quote! { {
565                use ::sqlx::Arguments as _;
566                #(#errors; )*
567                #args_tokens
568                sqlx::__query_with_result(#q, query_args)
569            }
570            };
571            s.into()
572        }
573        sql_type::StatementType::Replace {
574            arguments,
575            returning,
576        } => {
577            let (args_tokens, q) = quote_args(
578                &mut errors,
579                &query.query,
580                query.last_span,
581                &query.args,
582                arguments,
583                dialect,
584            );
585            let s = match returning.as_ref() {
586                Some(returning) => {
587                    let (row_members, row_construct) = construct_row(returning);
588                    quote! { {
589                        use ::sqlx::Arguments as _;
590                        let _ = std::include_bytes!(#sp);
591                        #(#errors; )*
592                        #args_tokens
593
594                        struct Row {
595                            #(#row_members),*
596                        };
597                        sqlx::__query_with_result(#q, query_args).map(|row|
598                            Row{
599                                #(#row_construct),*
600                            }
601                        )
602                    }}
603                }
604                None => quote! { {
605                    use ::sqlx::Arguments as _;
606                    #(#errors; )*
607                    #args_tokens
608                    sqlx::__query_with_result(#q, query_args)
609                }
610                },
611            };
612            s.into()
613        }
614        sql_type::StatementType::Invalid => {
615            let s = quote! { {
616                #(#errors; )*;
617                todo!("Invalid")
618            }};
619            s.into()
620        }
621    }
622}
623
624fn construct_row2(columns: &[SelectTypeColumn]) -> Vec<proc_macro2::TokenStream> {
625    let mut row_construct = Vec::new();
626    for (i, c) in columns.iter().enumerate() {
627        let mut t = match c.type_.t {
628            sql_type::Type::U8 => quote! {u8},
629            sql_type::Type::I8 => quote! {i8},
630            sql_type::Type::U16 => quote! {u16},
631            sql_type::Type::I16 => quote! {i16},
632            sql_type::Type::U32 => quote! {u32},
633            sql_type::Type::I32 => quote! {i32},
634            sql_type::Type::U64 => quote! {u64},
635            sql_type::Type::I64 => quote! {i64},
636            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
637            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
638            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
639            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
640            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
641            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
642            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
643            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
644            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
645            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
646                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
647            }
648            sql_type::Type::Null => todo!("from_null"),
649            sql_type::Type::Invalid => quote! {i64},
650            sql_type::Type::Enum(_) => quote! {String},
651            sql_type::Type::Set(_) => quote! {String},
652            sql_type::Type::Args(_, _) => todo!("from_args"),
653            sql_type::Type::F32 => quote! {f32},
654            sql_type::Type::F64 => quote! {f64},
655            sql_type::Type::JSON => quote! {String},
656        };
657        let name = match &c.name {
658            Some(v) => v,
659            None => continue,
660        };
661
662        let ident = String::from("r#") + name.value;
663        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
664            ident
665        } else {
666            // TODO error
667            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
668            continue;
669        };
670
671        if !c.type_.not_null {
672            t = quote! {Option<#t>};
673        }
674        row_construct.push(quote! {
675            #ident: sqlx_type::arg_out::<#t, _, #i>(sqlx::Row::get(&row, #i))
676        });
677    }
678    row_construct
679}
680
681struct QueryAs {
682    as_: Ident,
683    query: String,
684    query_span: Span,
685    args: Vec<Expr>,
686    last_span: Span,
687}
688
689impl Parse for QueryAs {
690    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
691        let as_ = input.parse::<Ident>()?;
692        let _ = input.parse::<syn::token::Comma>()?;
693
694        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
695        let query: String = query_.iter().map(LitStr::value).collect();
696        let query_span = query_.span();
697
698        let mut last_span = query_span;
699        let mut args = Vec::new();
700        while !input.is_empty() {
701            let _ = input.parse::<syn::token::Comma>()?;
702            if input.is_empty() {
703                break;
704            }
705            let arg = input.parse::<Expr>()?;
706            last_span = arg.span();
707            args.push(arg);
708        }
709        Ok(Self {
710            as_,
711            query,
712            query_span,
713            args,
714            last_span,
715        })
716    }
717}
718
719/// A variant of query! which takes a path to an explicitly defined struct as the output type.
720///
721/// This lets you return the struct from a function or add your own trait implementations.
722#[proc_macro]
723pub fn query_as(input: TokenStream) -> TokenStream {
724    let query_as = syn::parse_macro_input!(input as QueryAs);
725    let (schemas, dialect) = SCHEMAS.deref();
726    let options = TypeOptions::new()
727        .dialect(dialect.clone())
728        .arguments(match &dialect {
729            SQLDialect::MariaDB => SQLArguments::QuestionMark,
730            SQLDialect::Sqlite => SQLArguments::QuestionMark,
731            SQLDialect::PostgreSQL => SQLArguments::Dollar,
732        })
733        .list_hack(true);
734    let mut issues = sql_type::Issues::new(&query_as.query);
735    let stmt = type_statement(schemas, &query_as.query, &mut issues, &options);
736
737    let mut errors = issues_to_errors(issues.into_vec(), &query_as.query, query_as.query_span);
738    match &stmt {
739        sql_type::StatementType::Select { columns, arguments } => {
740            let (args_tokens, q) = quote_args(
741                &mut errors,
742                &query_as.query,
743                query_as.last_span,
744                &query_as.args,
745                arguments,
746                dialect,
747            );
748
749            let row_construct = construct_row2(columns);
750            let row = query_as.as_;
751            let s = quote! { {
752                use ::sqlx::Arguments as _;
753                #(#errors; )*
754                #args_tokens
755                sqlx::__query_with_result(#q, query_args).map(|row|
756                    #row{
757                        #(#row_construct),*
758                    }
759                )
760            }};
761            //println!("TOKENS: {}", s);
762            s.into()
763        }
764        sql_type::StatementType::Delete { .. } => {
765            errors.push(
766                syn::Error::new(query_as.query_span, "DELETE not support in query_as")
767                    .to_compile_error(),
768            );
769            quote! { {
770                #(#errors; )*
771                todo!("delete")
772            }}
773            .into()
774        }
775        sql_type::StatementType::Insert {
776            returning: None, ..
777        } => {
778            errors.push(
779                syn::Error::new(
780                    query_as.query_span,
781                    "INSERT without RETURNING not support in query_as",
782                )
783                .to_compile_error(),
784            );
785            quote! { {
786                #(#errors; )*
787                todo!("insert")
788            }}
789            .into()
790        }
791        sql_type::StatementType::Insert {
792            arguments,
793            returning: Some(returning),
794            ..
795        } => {
796            let (args_tokens, q) = quote_args(
797                &mut errors,
798                &query_as.query,
799                query_as.last_span,
800                &query_as.args,
801                arguments,
802                dialect,
803            );
804
805            let row_construct = construct_row2(returning);
806            let row = query_as.as_;
807            let s = quote! { {
808                use ::sqlx::Arguments as _;
809                #(#errors; )*
810                #args_tokens
811                sqlx::__query_with_result(#q, query_args).map(|row|
812                    #row{
813                        #(#row_construct),*
814                    }
815                )
816            }};
817            s.into()
818        }
819        sql_type::StatementType::Update { .. } => {
820            errors.push(
821                syn::Error::new(query_as.query_span, "UPDATE not support in query_as")
822                    .to_compile_error(),
823            );
824            quote! { {
825                #(#errors; )*
826                todo!("update")
827            }}
828            .into()
829        }
830        sql_type::StatementType::Replace {
831            returning: None, ..
832        } => {
833            errors.push(
834                syn::Error::new(
835                    query_as.query_span,
836                    "REPLACE without RETURNING not support in query_as",
837                )
838                .to_compile_error(),
839            );
840            quote! { {
841                #(#errors; )*
842                todo!("replace")
843            }}
844            .into()
845        }
846        sql_type::StatementType::Replace {
847            arguments,
848            returning: Some(returning),
849            ..
850        } => {
851            let (args_tokens, q) = quote_args(
852                &mut errors,
853                &query_as.query,
854                query_as.last_span,
855                &query_as.args,
856                arguments,
857                dialect,
858            );
859
860            let row_construct = construct_row2(returning);
861            let row = query_as.as_;
862            let s = quote! { {
863                use ::sqlx::Arguments as _;
864                #(#errors; )*
865                #args_tokens
866                sqlx::__query_with_result(#q, query_args).map(|row|
867                    #row{
868                        #(#row_construct),*
869                    }
870                )
871            }};
872            s.into()
873        }
874        sql_type::StatementType::Invalid => quote! { {
875            #(#errors; )*;
876            todo!("invalid")
877        }}
878        .into(),
879    }
880}