sqlx_type_macro/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::ops::Deref;
4use std::path::PathBuf;
5
6use ariadne::{Color, Label, Report, ReportKind, Source};
7use once_cell::sync::Lazy;
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::{format_ident, quote, quote_spanned};
11use sql_type::schema::{parse_schemas, Schemas};
12use sql_type::{type_statement, Issue, SQLArguments, SQLDialect, SelectTypeColumn, TypeOptions};
13use syn::spanned::Spanned;
14use syn::{parse::Parse, punctuated::Punctuated, Expr, Ident, LitStr, Token};
15
16static SCHEMA_PATH: Lazy<PathBuf> = Lazy::new(|| {
17    let mut schema_path: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
18        .expect("`CARGO_schema_path` must be set")
19        .into();
20
21    schema_path.push("sqlx-type-schema.sql");
22
23    if !schema_path.exists() {
24        use serde::Deserialize;
25        use std::process::Command;
26
27        let cargo = std::env::var("CARGO").expect("`CARGO` must be set");
28        schema_path.pop();
29
30        let output = Command::new(cargo)
31            .args(["metadata", "--format-version=1"])
32            .current_dir(&schema_path)
33            .env_remove("__CARGO_FIX_PLZ")
34            .output()
35            .expect("Could not fetch metadata");
36
37        #[derive(Deserialize)]
38        struct CargoMetadata {
39            workspace_root: PathBuf,
40        }
41
42        let metadata: CargoMetadata =
43            serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
44
45        schema_path = metadata.workspace_root;
46        schema_path.push("sqlx-type-schema.sql");
47    }
48    if !schema_path.exists() {
49        panic!("Unable to locate sqlx-type-schema.sql");
50    }
51    schema_path
52});
53
54// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
55// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
56static SCHEMA_SRC: Lazy<String> =
57    Lazy::new(|| match std::fs::read_to_string(SCHEMA_PATH.as_path()) {
58        Ok(v) => v,
59        Err(e) => panic!(
60            "Unable to read schema from {:?}: {}",
61            SCHEMA_PATH.as_path(),
62            e
63        ),
64    });
65
66fn issue_to_report(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
67    let mut builder = Report::build(
68        match issue.level {
69            sql_type::Level::Warning => ReportKind::Warning,
70            sql_type::Level::Error => ReportKind::Error,
71        },
72        issue.span.clone(),
73    )
74    .with_config(ariadne::Config::default().with_color(false))
75    .with_label(
76        Label::new(issue.span)
77            .with_order(-1)
78            .with_priority(-1)
79            .with_message(issue.message),
80    );
81    for frag in issue.fragments {
82        builder = builder.with_label(Label::new(frag.span).with_message(frag.message));
83    }
84    builder.finish()
85}
86
87fn issue_to_report_color(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
88    let mut builder = Report::build(
89        match issue.level {
90            sql_type::Level::Warning => ReportKind::Warning,
91            sql_type::Level::Error => ReportKind::Error,
92        },
93        issue.span.clone(),
94    )
95    .with_label(
96        Label::new(issue.span)
97            .with_color(match issue.level {
98                sql_type::Level::Warning => Color::Yellow,
99                sql_type::Level::Error => Color::Red,
100            })
101            .with_order(-1)
102            .with_priority(-1)
103            .with_message(issue.message),
104    );
105    for frag in issue.fragments {
106        builder = builder.with_label(
107            Label::new(frag.span)
108                .with_color(Color::Blue)
109                .with_message(frag.message),
110        );
111    }
112    builder.finish()
113}
114
115struct NamedSource<'a>(&'a str, Source<&'a str>);
116
117impl<'a> ariadne::Cache<()> for &NamedSource<'a> {
118    type Storage = &'a str;
119
120    fn display<'b>(&self, _: &'b ()) -> Option<impl std::fmt::Display + 'b> {
121        Some(self.0.to_string())
122    }
123
124    fn fetch(&mut self, _: &()) -> Result<&Source<Self::Storage>, impl std::fmt::Debug> {
125        Ok::<_,()>(&self.1)
126    }
127}
128
129static SCHEMAS: Lazy<(Schemas, SQLDialect)> = Lazy::new(|| {
130    let schema_src = SCHEMA_SRC.as_str();
131    let dialect = if let Some(first_line) = schema_src.lines().next() {
132        if first_line.contains("sql-product: postgres") {
133            SQLDialect::PostgreSQL
134        } else if first_line.contains("sql-product: sqlite") {
135            SQLDialect::Sqlite
136        } else {
137            SQLDialect::MariaDB
138        }
139    } else {
140        SQLDialect::MariaDB
141    };
142
143    let options = TypeOptions::new().dialect(dialect.clone());
144    let mut issues = sql_type::Issues::new(schema_src);
145    let schemas = parse_schemas(schema_src, &mut issues, &options);
146    if !issues.is_ok() {
147        let source = NamedSource("sqlx-type-schema.sql", Source::from(schema_src));
148        let mut err = false;
149        for issue in issues.into_vec() {
150            if issue.level == sql_type::Level::Error {
151                err = true;
152            }
153            let r = issue_to_report_color(issue);
154            r.eprint(&source).unwrap();
155        }
156        if err {
157            panic!("Errors processing sqlx-type-schema.sql");
158        }
159    }
160    (schemas, dialect)
161});
162
163fn quote_args(
164    errors: &mut Vec<proc_macro2::TokenStream>,
165    query: &str,
166    last_span: Span,
167    args: &[Expr],
168    arguments: &[(sql_type::ArgumentKey<'_>, sql_type::FullType)],
169    dialect: &SQLDialect,
170) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
171    let cls = match dialect {
172        SQLDialect::MariaDB => quote!(sqlx::mysql::MySql),
173        SQLDialect::Sqlite => quote!(sqlx::sqlite::Sqlite),
174        SQLDialect::PostgreSQL => quote!(sqlx::postgres::Postgres),
175    };
176
177    let mut at = Vec::new();
178    let inv = sql_type::FullType::invalid();
179    for (k, v) in arguments {
180        match k {
181            sql_type::ArgumentKey::Index(i) => {
182                while at.len() <= *i {
183                    at.push(&inv);
184                }
185                at[*i] = v;
186            }
187            sql_type::ArgumentKey::Identifier(_) => {
188                errors.push(
189                    syn::Error::new(last_span.span(), "Named arguments not supported")
190                        .to_compile_error(),
191                );
192            }
193        }
194    }
195
196    if at.len() > args.len() {
197        errors.push(
198            syn::Error::new(
199                last_span,
200                format!("Expected {} additional arguments", at.len() - args.len()),
201            )
202            .to_compile_error(),
203        );
204    }
205
206    if let Some(args) = args.get(at.len()..) {
207        for arg in args {
208            errors.push(syn::Error::new(arg.span(), "unexpected argument").to_compile_error());
209        }
210    }
211
212    let arg_names = (0..args.len())
213        .map(|i| format_ident!("arg{}", i))
214        .collect::<Vec<_>>();
215
216    let mut arg_bindings = Vec::new();
217    let mut arg_add = Vec::new();
218
219    let mut list_lengths = Vec::new();
220
221    for ((qa, ta), name) in args.iter().zip(at).zip(&arg_names) {
222        let mut t = match ta.t {
223            sql_type::Type::U8 => quote! {u8},
224            sql_type::Type::I8 => quote! {i8},
225            sql_type::Type::U16 => quote! {u16},
226            sql_type::Type::I16 => quote! {i16},
227            sql_type::Type::U32 => quote! {u32},
228            sql_type::Type::I32 => quote! {i32},
229            sql_type::Type::U64 => quote! {u64},
230            sql_type::Type::I64 => quote! {i64},
231            sql_type::Type::Base(sql_type::BaseType::Any) => quote! {sqlx_type::Any},
232            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
233            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {&[u8]},
234            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {sqlx_type::Date},
235            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {sqlx_type::DateTime},
236            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {sqlx_type::Float},
237            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {sqlx_type::Integer},
238            sql_type::Type::Base(sql_type::BaseType::String) => quote! {&str},
239            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("time"),
240            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => quote! {sqlx_type::Timestamp},
241            sql_type::Type::Null => todo!("null"),
242            sql_type::Type::Invalid => quote! {std::convert::Infallible},
243            sql_type::Type::Enum(_) => quote! {&str},
244            sql_type::Type::Set(_) => quote! {&str},
245            sql_type::Type::Args(_, _) => todo!("args"),
246            sql_type::Type::F32 => quote! {f32},
247            sql_type::Type::F64 => quote! {f64},
248            sql_type::Type::JSON => quote! {sqlx_type::Any},
249        };
250        if !ta.not_null {
251            t = quote! {Option<#t>}
252        }
253        let span = qa.span();
254        if ta.list_hack {
255            list_lengths.push(quote!(#name.len()));
256            arg_bindings.push(quote_spanned! {span=>
257                let #name = &(#qa);
258                args_count += #name.len();
259                for v in #name.iter() {
260                    size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(v);
261                }
262                if false {
263                    sqlx_type::check_arg_list_hack::<#t, _>(#name);
264                    ::std::panic!();
265                }
266            });
267            arg_add.push(quote!(
268                for v in #name.iter() {
269                    e = e.and_then(|()| query_args.add(v));
270                }
271            ));
272        } else {
273            arg_bindings.push(quote_spanned! {span=>
274                let #name = &(#qa);
275                args_count += 1;
276                size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(#name);
277                if false {
278                    sqlx_type::check_arg::<#t, _>(#name);
279                    ::std::panic!();
280                }
281            });
282            arg_add.push(quote!(e = e.and_then(|()| query_args.add(#name));));
283        }
284    }
285
286    let query = if list_lengths.is_empty() {
287        quote!(#query)
288    } else {
289        quote!(
290            &sqlx_type::convert_list_query(#query, &[#(#list_lengths),*])
291        )
292    };
293
294    (
295        quote! {
296            let mut size_hints = 0;
297            let mut args_count = 0;
298            #(#arg_bindings)*
299
300            let mut query_args = <#cls as ::sqlx::database::Database>::Arguments::default();
301            query_args.reserve(args_count, size_hints);
302            let mut e = Ok(());
303            #(#arg_add)*
304            let query_args = e.and_then(|()| Ok(query_args));
305        },
306        query,
307    )
308}
309
310fn issues_to_errors(issues: Vec<Issue>, source: &str, span: Span) -> Vec<proc_macro2::TokenStream> {
311    if !issues.is_empty() {
312        let source = NamedSource("", Source::from(source));
313        let mut err = false;
314        let mut out = Vec::new();
315        for issue in issues {
316            if issue.level == sql_type::Level::Error {
317                err = true;
318            }
319            let r = issue_to_report(issue);
320            r.write(&source, &mut out).unwrap();
321        }
322        if err {
323            return vec![syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error()];
324        }
325    }
326    Vec::new()
327}
328
329fn construct_row(
330    columns: &[SelectTypeColumn],
331) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
332    let mut row_members = Vec::new();
333    let mut row_construct = Vec::new();
334    for (i, c) in columns.iter().enumerate() {
335        let mut t = match c.type_.t {
336            sql_type::Type::U8 => quote! {u8},
337            sql_type::Type::I8 => quote! {i8},
338            sql_type::Type::U16 => quote! {u16},
339            sql_type::Type::I16 => quote! {i16},
340            sql_type::Type::U32 => quote! {u32},
341            sql_type::Type::I32 => quote! {i32},
342            sql_type::Type::U64 => quote! {u64},
343            sql_type::Type::I64 => quote! {i64},
344            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
345            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
346            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
347            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
348            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
349            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
350            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
351            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
352            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
353            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
354                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
355            }
356            sql_type::Type::Null => todo!("from_null"),
357            sql_type::Type::Invalid => quote! {i64},
358            sql_type::Type::Enum(_) => quote! {String},
359            sql_type::Type::Set(_) => quote! {String},
360            sql_type::Type::Args(_, _) => todo!("from_args"),
361            sql_type::Type::F32 => quote! {f32},
362            sql_type::Type::F64 => quote! {f64},
363            sql_type::Type::JSON => quote! {String},
364        };
365        let name = match &c.name {
366            Some(v) => v,
367            None => continue,
368        };
369
370        let ident = String::from("r#") + name.value;
371        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
372            ident
373        } else {
374            // TODO error
375            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
376            continue;
377        };
378
379        if !c.type_.not_null {
380            t = quote! {Option<#t>};
381        }
382        row_members.push(quote! {
383            #ident : #t
384        });
385        row_construct.push(quote! {
386            #ident: sqlx::Row::get(&row, #i)
387        });
388    }
389    (row_members, row_construct)
390}
391
392struct Query {
393    query: String,
394    query_span: Span,
395    args: Vec<Expr>,
396    last_span: Span,
397}
398
399impl Parse for Query {
400    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
401        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
402        let query: String = query_.iter().map(LitStr::value).collect();
403        let query_span = query_.span();
404        let mut last_span = query_span;
405        let mut args = Vec::new();
406        while !input.is_empty() {
407            let _ = input.parse::<syn::token::Comma>()?;
408            if input.is_empty() {
409                break;
410            }
411            let arg = input.parse::<Expr>()?;
412            last_span = arg.span();
413            args.push(arg);
414        }
415        Ok(Self {
416            query,
417            query_span,
418            args,
419            last_span,
420        })
421    }
422}
423
424/// Statically checked SQL query, similarly to sqlx::query!.
425///
426/// This expands to an instance of query::Map that outputs an ad-hoc anonymous struct type.
427#[proc_macro]
428pub fn query(input: TokenStream) -> TokenStream {
429    let query = syn::parse_macro_input!(input as Query);
430    let (schemas, dialect) = SCHEMAS.deref();
431    let options = TypeOptions::new()
432        .dialect(dialect.clone())
433        .arguments(match &dialect {
434            SQLDialect::MariaDB => SQLArguments::QuestionMark,
435            SQLDialect::Sqlite => SQLArguments::QuestionMark,
436            SQLDialect::PostgreSQL => SQLArguments::Dollar,
437        })
438        .list_hack(true);
439    let mut issues = sql_type::Issues::new(&query.query);
440    let stmt = type_statement(schemas, &query.query, &mut issues, &options);
441    let sp = SCHEMA_PATH.as_path().to_str().unwrap();
442
443    let mut errors = issues_to_errors(issues.into_vec(), &query.query, query.query_span);
444    match &stmt {
445        sql_type::StatementType::Select { columns, arguments } => {
446            let (args_tokens, q) = quote_args(
447                &mut errors,
448                &query.query,
449                query.last_span,
450                &query.args,
451                arguments,
452                dialect,
453            );
454            let (row_members, row_construct) = construct_row(columns);
455            let s = quote! { {
456                use ::sqlx::Arguments as _;
457                let _ = std::include_bytes!(#sp);
458                #(#errors; )*
459                #args_tokens;
460
461                struct Row {
462                    #(#row_members),*
463                };
464                sqlx::__query_with_result(#q, query_args).map(|row|
465                    Row{
466                        #(#row_construct),*
467                    }
468                )
469            }};
470            s.into()
471        }
472        sql_type::StatementType::Delete {
473            arguments,
474            returning,
475        } => {
476            let (args_tokens, q) = quote_args(
477                &mut errors,
478                &query.query,
479                query.last_span,
480                &query.args,
481                arguments,
482                dialect,
483            );
484            let s = match returning.as_ref() {
485                Some(returning) => {
486                    let (row_members, row_construct) = construct_row(returning);
487                    quote! { {
488                        use ::sqlx::Arguments as _;
489                        let _ = std::include_bytes!(#sp);
490                        #(#errors; )*
491                        #args_tokens
492
493                        struct Row {
494                            #(#row_members),*
495                        };
496                        sqlx::__query_with_result(#q, query_args).map(|row|
497                            Row{
498                                #(#row_construct),*
499                            }
500                        )
501                    }}
502                }
503                None => quote! { {
504                    use ::sqlx::Arguments as _;
505                    #(#errors; )*
506                    #args_tokens
507                    sqlx::__query_with_result(#q, query_args)
508                }
509                },
510            };
511            s.into()
512        }
513        sql_type::StatementType::Insert {
514            arguments,
515            returning,
516            ..
517        } => {
518            let (args_tokens, q) = quote_args(
519                &mut errors,
520                &query.query,
521                query.last_span,
522                &query.args,
523                arguments,
524                dialect,
525            );
526            let s = match returning.as_ref() {
527                Some(returning) => {
528                    let (row_members, row_construct) = construct_row(returning);
529                    quote! { {
530                        use ::sqlx::Arguments as _;
531                        let _ = std::include_bytes!(#sp);
532                        #(#errors; )*
533                        #args_tokens
534
535                        struct Row {
536                            #(#row_members),*
537                        };
538                        sqlx::__query_with_result(#q, query_args).map(|row|
539                            Row{
540                                #(#row_construct),*
541                            }
542                        )
543                    }}
544                }
545                None => quote! { {
546                    use ::sqlx::Arguments as _;
547                    #(#errors; )*
548                    #args_tokens
549                    sqlx::__query_with_result(#q, query_args)
550                }
551                },
552            };
553            s.into()
554        }
555        sql_type::StatementType::Update { arguments , returning} => {
556            let (args_tokens, q) = quote_args(
557                &mut errors,
558                &query.query,
559                query.last_span,
560                &query.args,
561                arguments,
562                dialect,
563            );
564
565             let s = match returning.as_ref() {
566                Some(returning) => {
567                    let (row_members, row_construct) = construct_row(returning);
568                    quote! { {
569                        use ::sqlx::Arguments as _;
570                        let _ = std::include_bytes!(#sp);
571                        #(#errors; )*
572                        #args_tokens
573
574                        struct Row {
575                            #(#row_members),*
576                        };
577                        sqlx::__query_with_result(#q, query_args).map(|row|
578                            Row{
579                                #(#row_construct),*
580                            }
581                        )
582                    }}
583                }
584                None =>  quote! { {
585                    use ::sqlx::Arguments as _;
586                    #(#errors; )*
587                    #args_tokens
588                    sqlx::__query_with_result(#q, query_args)
589                }
590                },
591            };
592            s.into()
593        }
594        sql_type::StatementType::Replace {
595            arguments,
596            returning,
597        } => {
598            let (args_tokens, q) = quote_args(
599                &mut errors,
600                &query.query,
601                query.last_span,
602                &query.args,
603                arguments,
604                dialect,
605            );
606            let s = match returning.as_ref() {
607                Some(returning) => {
608                    let (row_members, row_construct) = construct_row(returning);
609                    quote! { {
610                        use ::sqlx::Arguments as _;
611                        let _ = std::include_bytes!(#sp);
612                        #(#errors; )*
613                        #args_tokens
614
615                        struct Row {
616                            #(#row_members),*
617                        };
618                        sqlx::__query_with_result(#q, query_args).map(|row|
619                            Row{
620                                #(#row_construct),*
621                            }
622                        )
623                    }}
624                }
625                None => quote! { {
626                    use ::sqlx::Arguments as _;
627                    #(#errors; )*
628                    #args_tokens
629                    sqlx::__query_with_result(#q, query_args)
630                }
631                },
632            };
633            s.into()
634        }
635        sql_type::StatementType::Invalid => {
636            let s = quote! { {
637                #(#errors; )*;
638                todo!("Invalid")
639            }};
640            s.into()
641        }
642    }
643}
644
645fn construct_row2(columns: &[SelectTypeColumn]) -> Vec<proc_macro2::TokenStream> {
646    let mut row_construct = Vec::new();
647    for (i, c) in columns.iter().enumerate() {
648        let mut t = match c.type_.t {
649            sql_type::Type::U8 => quote! {u8},
650            sql_type::Type::I8 => quote! {i8},
651            sql_type::Type::U16 => quote! {u16},
652            sql_type::Type::I16 => quote! {i16},
653            sql_type::Type::U32 => quote! {u32},
654            sql_type::Type::I32 => quote! {i32},
655            sql_type::Type::U64 => quote! {u64},
656            sql_type::Type::I64 => quote! {i64},
657            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
658            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
659            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
660            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
661            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
662            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
663            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
664            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
665            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
666            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
667                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
668            }
669            sql_type::Type::Null => todo!("from_null"),
670            sql_type::Type::Invalid => quote! {i64},
671            sql_type::Type::Enum(_) => quote! {String},
672            sql_type::Type::Set(_) => quote! {String},
673            sql_type::Type::Args(_, _) => todo!("from_args"),
674            sql_type::Type::F32 => quote! {f32},
675            sql_type::Type::F64 => quote! {f64},
676            sql_type::Type::JSON => quote! {String},
677        };
678        let name = match &c.name {
679            Some(v) => v,
680            None => continue,
681        };
682
683        let ident = String::from("r#") + name.value;
684        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
685            ident
686        } else {
687            // TODO error
688            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
689            continue;
690        };
691
692        if !c.type_.not_null {
693            t = quote! {Option<#t>};
694        }
695        row_construct.push(quote! {
696            #ident: sqlx_type::arg_out::<#t, _, #i>(sqlx::Row::get(&row, #i))
697        });
698    }
699    row_construct
700}
701
702struct QueryAs {
703    as_: Ident,
704    query: String,
705    query_span: Span,
706    args: Vec<Expr>,
707    last_span: Span,
708}
709
710impl Parse for QueryAs {
711    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
712        let as_ = input.parse::<Ident>()?;
713        let _ = input.parse::<syn::token::Comma>()?;
714
715        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
716        let query: String = query_.iter().map(LitStr::value).collect();
717        let query_span = query_.span();
718
719        let mut last_span = query_span;
720        let mut args = Vec::new();
721        while !input.is_empty() {
722            let _ = input.parse::<syn::token::Comma>()?;
723            if input.is_empty() {
724                break;
725            }
726            let arg = input.parse::<Expr>()?;
727            last_span = arg.span();
728            args.push(arg);
729        }
730        Ok(Self {
731            as_,
732            query,
733            query_span,
734            args,
735            last_span,
736        })
737    }
738}
739
740/// A variant of query! which takes a path to an explicitly defined struct as the output type.
741///
742/// This lets you return the struct from a function or add your own trait implementations.
743#[proc_macro]
744pub fn query_as(input: TokenStream) -> TokenStream {
745    let query_as = syn::parse_macro_input!(input as QueryAs);
746    let (schemas, dialect) = SCHEMAS.deref();
747    let options = TypeOptions::new()
748        .dialect(dialect.clone())
749        .arguments(match &dialect {
750            SQLDialect::MariaDB => SQLArguments::QuestionMark,
751            SQLDialect::Sqlite => SQLArguments::QuestionMark,
752            SQLDialect::PostgreSQL => SQLArguments::Dollar,
753        })
754        .list_hack(true);
755    let mut issues = sql_type::Issues::new(&query_as.query);
756    let stmt = type_statement(schemas, &query_as.query, &mut issues, &options);
757
758    let mut errors = issues_to_errors(issues.into_vec(), &query_as.query, query_as.query_span);
759    match &stmt {
760        sql_type::StatementType::Select { columns, arguments } => {
761            let (args_tokens, q) = quote_args(
762                &mut errors,
763                &query_as.query,
764                query_as.last_span,
765                &query_as.args,
766                arguments,
767                dialect,
768            );
769
770            let row_construct = construct_row2(columns);
771            let row = query_as.as_;
772            let s = quote! { {
773                use ::sqlx::Arguments as _;
774                #(#errors; )*
775                #args_tokens
776                sqlx::__query_with_result(#q, query_args).map(|row|
777                    #row{
778                        #(#row_construct),*
779                    }
780                )
781            }};
782            //println!("TOKENS: {}", s);
783            s.into()
784        }
785        sql_type::StatementType::Delete { .. } => {
786            errors.push(
787                syn::Error::new(query_as.query_span, "DELETE not support in query_as")
788                    .to_compile_error(),
789            );
790            quote! { {
791                #(#errors; )*
792                todo!("delete")
793            }}
794            .into()
795        }
796        sql_type::StatementType::Insert {
797            returning: None, ..
798        } => {
799            errors.push(
800                syn::Error::new(
801                    query_as.query_span,
802                    "INSERT without RETURNING not support in query_as",
803                )
804                .to_compile_error(),
805            );
806            quote! { {
807                #(#errors; )*
808                todo!("insert")
809            }}
810            .into()
811        }
812        sql_type::StatementType::Insert {
813            arguments,
814            returning: Some(returning),
815            ..
816        } => {
817            let (args_tokens, q) = quote_args(
818                &mut errors,
819                &query_as.query,
820                query_as.last_span,
821                &query_as.args,
822                arguments,
823                dialect,
824            );
825
826            let row_construct = construct_row2(returning);
827            let row = query_as.as_;
828            let s = quote! { {
829                use ::sqlx::Arguments as _;
830                #(#errors; )*
831                #args_tokens
832                sqlx::__query_with_result(#q, query_args).map(|row|
833                    #row{
834                        #(#row_construct),*
835                    }
836                )
837            }};
838            s.into()
839        }
840        sql_type::StatementType::Update { .. } => {
841            errors.push(
842                syn::Error::new(query_as.query_span, "UPDATE not support in query_as")
843                    .to_compile_error(),
844            );
845            quote! { {
846                #(#errors; )*
847                todo!("update")
848            }}
849            .into()
850        }
851        sql_type::StatementType::Replace {
852            returning: None, ..
853        } => {
854            errors.push(
855                syn::Error::new(
856                    query_as.query_span,
857                    "REPLACE without RETURNING not support in query_as",
858                )
859                .to_compile_error(),
860            );
861            quote! { {
862                #(#errors; )*
863                todo!("replace")
864            }}
865            .into()
866        }
867        sql_type::StatementType::Replace {
868            arguments,
869            returning: Some(returning),
870            ..
871        } => {
872            let (args_tokens, q) = quote_args(
873                &mut errors,
874                &query_as.query,
875                query_as.last_span,
876                &query_as.args,
877                arguments,
878                dialect,
879            );
880
881            let row_construct = construct_row2(returning);
882            let row = query_as.as_;
883            let s = quote! { {
884                use ::sqlx::Arguments as _;
885                #(#errors; )*
886                #args_tokens
887                sqlx::__query_with_result(#q, query_args).map(|row|
888                    #row{
889                        #(#row_construct),*
890                    }
891                )
892            }};
893            s.into()
894        }
895        sql_type::StatementType::Invalid => quote! { {
896            #(#errors; )*;
897            todo!("invalid")
898        }}
899        .into(),
900    }
901}