sqlx_type_macro/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::ops::Deref;
4use std::path::PathBuf;
5
6use ariadne::{Color, Label, Report, ReportKind, Source};
7use once_cell::sync::Lazy;
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::{format_ident, quote, quote_spanned};
11use sql_type::schema::{parse_schemas, Schemas};
12use sql_type::{type_statement, Issue, SQLArguments, SQLDialect, SelectTypeColumn, TypeOptions};
13use syn::spanned::Spanned;
14use syn::{parse::Parse, punctuated::Punctuated, Expr, Ident, LitStr, Token};
15
16static SCHEMA_PATH: Lazy<PathBuf> = Lazy::new(|| {
17    let mut schema_path: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
18        .expect("`CARGO_schema_path` must be set")
19        .into();
20
21    schema_path.push("sqlx-type-schema.sql");
22
23    if !schema_path.exists() {
24        use serde::Deserialize;
25        use std::process::Command;
26
27        let cargo = std::env::var("CARGO").expect("`CARGO` must be set");
28        schema_path.pop();
29
30        let output = Command::new(cargo)
31            .args(["metadata", "--format-version=1"])
32            .current_dir(&schema_path)
33            .env_remove("__CARGO_FIX_PLZ")
34            .output()
35            .expect("Could not fetch metadata");
36
37        #[derive(Deserialize)]
38        struct CargoMetadata {
39            workspace_root: PathBuf,
40        }
41
42        let metadata: CargoMetadata =
43            serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
44
45        schema_path = metadata.workspace_root;
46        schema_path.push("sqlx-type-schema.sql");
47    }
48    if !schema_path.exists() {
49        panic!("Unable to locate sqlx-type-schema.sql");
50    }
51    schema_path
52});
53
54// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
55// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
56static SCHEMA_SRC: Lazy<String> =
57    Lazy::new(|| match std::fs::read_to_string(SCHEMA_PATH.as_path()) {
58        Ok(v) => v,
59        Err(e) => panic!(
60            "Unable to read schema from {:?}: {}",
61            SCHEMA_PATH.as_path(),
62            e
63        ),
64    });
65
66fn issue_to_report(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
67    let mut builder = Report::build(
68        match issue.level {
69            sql_type::Level::Warning => ReportKind::Warning,
70            sql_type::Level::Error => ReportKind::Error,
71        },
72        issue.span.clone(),
73    )
74    .with_config(ariadne::Config::default().with_color(false))
75    .with_label(
76        Label::new(issue.span)
77            .with_order(-1)
78            .with_priority(-1)
79            .with_message(issue.message),
80    );
81    for frag in issue.fragments {
82        builder = builder.with_label(Label::new(frag.span).with_message(frag.message));
83    }
84    builder.finish()
85}
86
87fn issue_to_report_color(issue: Issue) -> Report<'static, std::ops::Range<usize>> {
88    let mut builder = Report::build(
89        match issue.level {
90            sql_type::Level::Warning => ReportKind::Warning,
91            sql_type::Level::Error => ReportKind::Error,
92        },
93        issue.span.clone(),
94    )
95    .with_label(
96        Label::new(issue.span)
97            .with_color(match issue.level {
98                sql_type::Level::Warning => Color::Yellow,
99                sql_type::Level::Error => Color::Red,
100            })
101            .with_order(-1)
102            .with_priority(-1)
103            .with_message(issue.message),
104    );
105    for frag in issue.fragments {
106        builder = builder.with_label(
107            Label::new(frag.span)
108                .with_color(Color::Blue)
109                .with_message(frag.message),
110        );
111    }
112    builder.finish()
113}
114
115struct NamedSource<'a>(&'a str, Source<&'a str>);
116
117impl<'a> ariadne::Cache<()> for &NamedSource<'a> {
118    type Storage = &'a str;
119
120    fn display<'b>(&self, _: &'b ()) -> Option<impl std::fmt::Display + 'b> {
121        Some(self.0.to_string())
122    }
123
124    fn fetch(&mut self, _: &()) -> Result<&Source<Self::Storage>, impl std::fmt::Debug> {
125        Ok::<_, ()>(&self.1)
126    }
127}
128
129static SCHEMAS: Lazy<(Schemas, SQLDialect)> = Lazy::new(|| {
130    let schema_src = SCHEMA_SRC.as_str();
131    let dialect = if let Some(first_line) = schema_src.lines().next() {
132        if first_line.contains("sql-product: postgres") {
133            SQLDialect::PostgreSQL
134        } else if first_line.contains("sql-product: sqlite") {
135            SQLDialect::Sqlite
136        } else {
137            SQLDialect::MariaDB
138        }
139    } else {
140        SQLDialect::MariaDB
141    };
142
143    let options = TypeOptions::new().dialect(dialect.clone());
144    let mut issues = sql_type::Issues::new(schema_src);
145    let schemas = parse_schemas(schema_src, &mut issues, &options);
146    if !issues.is_ok() {
147        let source = NamedSource("sqlx-type-schema.sql", Source::from(schema_src));
148        let mut err = false;
149        for issue in issues.into_vec() {
150            if issue.level == sql_type::Level::Error {
151                err = true;
152            }
153            let r = issue_to_report_color(issue);
154            r.eprint(&source).unwrap();
155        }
156        if err {
157            panic!("Errors processing sqlx-type-schema.sql");
158        }
159    }
160    (schemas, dialect)
161});
162
163fn quote_args(
164    errors: &mut Vec<proc_macro2::TokenStream>,
165    query: &str,
166    last_span: Span,
167    args: &[Expr],
168    arguments: &[(sql_type::ArgumentKey<'_>, sql_type::FullType)],
169    dialect: &SQLDialect,
170) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
171    let cls = match dialect {
172        SQLDialect::MariaDB => quote!(sqlx::mysql::MySql),
173        SQLDialect::Sqlite => quote!(sqlx::sqlite::Sqlite),
174        SQLDialect::PostgreSQL => quote!(sqlx::postgres::Postgres),
175    };
176
177    let mut at = Vec::new();
178    let inv = sql_type::FullType::invalid();
179    for (k, v) in arguments {
180        match k {
181            sql_type::ArgumentKey::Index(i) => {
182                while at.len() <= *i {
183                    at.push(&inv);
184                }
185                at[*i] = v;
186            }
187            sql_type::ArgumentKey::Identifier(_) => {
188                errors.push(
189                    syn::Error::new(last_span.span(), "Named arguments not supported")
190                        .to_compile_error(),
191                );
192            }
193        }
194    }
195
196    if at.len() > args.len() {
197        errors.push(
198            syn::Error::new(
199                last_span,
200                format!("Expected {} additional arguments", at.len() - args.len()),
201            )
202            .to_compile_error(),
203        );
204    }
205
206    if let Some(args) = args.get(at.len()..) {
207        for arg in args {
208            errors.push(syn::Error::new(arg.span(), "unexpected argument").to_compile_error());
209        }
210    }
211
212    let arg_names = (0..args.len())
213        .map(|i| format_ident!("arg{}", i))
214        .collect::<Vec<_>>();
215
216    let mut arg_bindings = Vec::new();
217    let mut arg_add = Vec::new();
218
219    let mut list_lengths = Vec::new();
220
221    for ((qa, ta), name) in args.iter().zip(at).zip(&arg_names) {
222        let mut t = match ta.t {
223            sql_type::Type::U8 => quote! {u8},
224            sql_type::Type::I8 => quote! {i8},
225            sql_type::Type::U16 => quote! {u16},
226            sql_type::Type::I16 => quote! {i16},
227            sql_type::Type::U32 => quote! {u32},
228            sql_type::Type::I32 => quote! {i32},
229            sql_type::Type::U64 => quote! {u64},
230            sql_type::Type::I64 => quote! {i64},
231            sql_type::Type::Base(sql_type::BaseType::Any) => quote! {sqlx_type::Any},
232            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
233            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {&[u8]},
234            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {sqlx_type::Date},
235            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {sqlx_type::DateTime},
236            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {sqlx_type::Float},
237            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {sqlx_type::Integer},
238            sql_type::Type::Base(sql_type::BaseType::String) => quote! {&str},
239            sql_type::Type::Base(sql_type::BaseType::Time) => quote! {sqlx_type::Time},
240            sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("time_interval"),
241            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => quote! {sqlx_type::Timestamp},
242            sql_type::Type::Null => todo!("null"),
243            sql_type::Type::Invalid => quote! {std::convert::Infallible},
244            sql_type::Type::Enum(_) => quote! {&str},
245            sql_type::Type::Set(_) => quote! {&str},
246            sql_type::Type::Args(_, _) => todo!("args"),
247            sql_type::Type::F32 => quote! {f32},
248            sql_type::Type::F64 => quote! {f64},
249            sql_type::Type::JSON => quote! {sqlx_type::Any},
250        };
251        if !ta.not_null {
252            t = quote! {Option<#t>}
253        }
254        let span = qa.span();
255        if ta.list_hack {
256            list_lengths.push(quote!(#name.len()));
257            arg_bindings.push(quote_spanned! {span=>
258                let #name = &(#qa);
259                args_count += #name.len();
260                for v in #name.iter() {
261                    size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(v);
262                }
263                if false {
264                    sqlx_type::check_arg_list_hack::<#t, _>(#name);
265                    ::std::panic!();
266                }
267            });
268            arg_add.push(quote!(
269                for v in #name.iter() {
270                    e = e.and_then(|()| query_args.add(v));
271                }
272            ));
273        } else {
274            arg_bindings.push(quote_spanned! {span=>
275                let #name = &(#qa);
276                args_count += 1;
277                size_hints += ::sqlx::encode::Encode::<#cls>::size_hint(#name);
278                if false {
279                    sqlx_type::check_arg::<#t, _>(#name);
280                    ::std::panic!();
281                }
282            });
283            arg_add.push(quote!(e = e.and_then(|()| query_args.add(#name));));
284        }
285    }
286
287    let query = if list_lengths.is_empty() {
288        quote!(#query)
289    } else {
290        quote!(
291            &sqlx_type::convert_list_query(#query, &[#(#list_lengths),*])
292        )
293    };
294
295    (
296        quote! {
297            let mut size_hints = 0;
298            let mut args_count = 0;
299            #(#arg_bindings)*
300
301            let mut query_args = <#cls as ::sqlx::database::Database>::Arguments::default();
302            query_args.reserve(args_count, size_hints);
303            let mut e = Ok(());
304            #(#arg_add)*
305            let query_args = e.and_then(|()| Ok(query_args));
306        },
307        query,
308    )
309}
310
311fn issues_to_errors(issues: Vec<Issue>, source: &str, span: Span) -> Vec<proc_macro2::TokenStream> {
312    if !issues.is_empty() {
313        let source = NamedSource("", Source::from(source));
314        let mut err = false;
315        let mut out = Vec::new();
316        for issue in issues {
317            if issue.level == sql_type::Level::Error {
318                err = true;
319            }
320            let r = issue_to_report(issue);
321            r.write(&source, &mut out).unwrap();
322        }
323        if err {
324            return vec![syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error()];
325        }
326    }
327    Vec::new()
328}
329
330fn construct_row(
331    columns: &[SelectTypeColumn],
332) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
333    let mut row_members = Vec::new();
334    let mut row_construct = Vec::new();
335    for (i, c) in columns.iter().enumerate() {
336        let mut t = match c.type_.t {
337            sql_type::Type::U8 => quote! {u8},
338            sql_type::Type::I8 => quote! {i8},
339            sql_type::Type::U16 => quote! {u16},
340            sql_type::Type::I16 => quote! {i16},
341            sql_type::Type::U32 => quote! {u32},
342            sql_type::Type::I32 => quote! {i32},
343            sql_type::Type::U64 => quote! {u64},
344            sql_type::Type::I64 => quote! {i64},
345            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
346            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
347            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
348            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
349            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
350            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
351            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
352            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
353            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
354            sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("from_time_interval"),
355            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
356                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
357            }
358            sql_type::Type::Null => todo!("from_null"),
359            sql_type::Type::Invalid => quote! {i64},
360            sql_type::Type::Enum(_) => quote! {String},
361            sql_type::Type::Set(_) => quote! {String},
362            sql_type::Type::Args(_, _) => todo!("from_args"),
363            sql_type::Type::F32 => quote! {f32},
364            sql_type::Type::F64 => quote! {f64},
365            sql_type::Type::JSON => quote! {String},
366        };
367        let name = match &c.name {
368            Some(v) => v,
369            None => continue,
370        };
371
372        let ident = String::from("r#") + name.value;
373        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
374            ident
375        } else {
376            // TODO error
377            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
378            continue;
379        };
380
381        if !c.type_.not_null {
382            t = quote! {Option<#t>};
383        }
384        row_members.push(quote! {
385            #ident : #t
386        });
387        row_construct.push(quote! {
388            #ident: sqlx::Row::get(&row, #i)
389        });
390    }
391    (row_members, row_construct)
392}
393
394struct Query {
395    query: String,
396    query_span: Span,
397    args: Vec<Expr>,
398    last_span: Span,
399}
400
401impl Parse for Query {
402    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
403        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
404        let query: String = query_.iter().map(LitStr::value).collect();
405        let query_span = query_.span();
406        let mut last_span = query_span;
407        let mut args = Vec::new();
408        while !input.is_empty() {
409            let _ = input.parse::<syn::token::Comma>()?;
410            if input.is_empty() {
411                break;
412            }
413            let arg = input.parse::<Expr>()?;
414            last_span = arg.span();
415            args.push(arg);
416        }
417        Ok(Self {
418            query,
419            query_span,
420            args,
421            last_span,
422        })
423    }
424}
425
426/// Statically checked SQL query, similarly to sqlx::query!.
427///
428/// This expands to an instance of query::Map that outputs an ad-hoc anonymous struct type.
429#[proc_macro]
430pub fn query(input: TokenStream) -> TokenStream {
431    let query = syn::parse_macro_input!(input as Query);
432    let (schemas, dialect) = SCHEMAS.deref();
433    let options = TypeOptions::new()
434        .dialect(dialect.clone())
435        .arguments(match &dialect {
436            SQLDialect::MariaDB => SQLArguments::QuestionMark,
437            SQLDialect::Sqlite => SQLArguments::QuestionMark,
438            SQLDialect::PostgreSQL => SQLArguments::Dollar,
439        })
440        .list_hack(true);
441    let mut issues = sql_type::Issues::new(&query.query);
442    let stmt = type_statement(schemas, &query.query, &mut issues, &options);
443    let sp = SCHEMA_PATH.as_path().to_str().unwrap();
444
445    let mut errors = issues_to_errors(issues.into_vec(), &query.query, query.query_span);
446    match &stmt {
447        sql_type::StatementType::Select { columns, arguments } => {
448            let (args_tokens, q) = quote_args(
449                &mut errors,
450                &query.query,
451                query.last_span,
452                &query.args,
453                arguments,
454                dialect,
455            );
456            let (row_members, row_construct) = construct_row(columns);
457            let s = quote! { {
458                use ::sqlx::Arguments as _;
459                let _ = std::include_bytes!(#sp);
460                #(#errors; )*
461                #args_tokens;
462
463                struct Row {
464                    #(#row_members),*
465                };
466                sqlx::__query_with_result(#q, query_args).map(|row|
467                    Row{
468                        #(#row_construct),*
469                    }
470                )
471            }};
472            s.into()
473        }
474        sql_type::StatementType::Delete {
475            arguments,
476            returning,
477        } => {
478            let (args_tokens, q) = quote_args(
479                &mut errors,
480                &query.query,
481                query.last_span,
482                &query.args,
483                arguments,
484                dialect,
485            );
486            let s = match returning.as_ref() {
487                Some(returning) => {
488                    let (row_members, row_construct) = construct_row(returning);
489                    quote! { {
490                        use ::sqlx::Arguments as _;
491                        let _ = std::include_bytes!(#sp);
492                        #(#errors; )*
493                        #args_tokens
494
495                        struct Row {
496                            #(#row_members),*
497                        };
498                        sqlx::__query_with_result(#q, query_args).map(|row|
499                            Row{
500                                #(#row_construct),*
501                            }
502                        )
503                    }}
504                }
505                None => quote! { {
506                    use ::sqlx::Arguments as _;
507                    #(#errors; )*
508                    #args_tokens
509                    sqlx::__query_with_result(#q, query_args)
510                }
511                },
512            };
513            s.into()
514        }
515        sql_type::StatementType::Insert {
516            arguments,
517            returning,
518            ..
519        } => {
520            let (args_tokens, q) = quote_args(
521                &mut errors,
522                &query.query,
523                query.last_span,
524                &query.args,
525                arguments,
526                dialect,
527            );
528            let s = match returning.as_ref() {
529                Some(returning) => {
530                    let (row_members, row_construct) = construct_row(returning);
531                    quote! { {
532                        use ::sqlx::Arguments as _;
533                        let _ = std::include_bytes!(#sp);
534                        #(#errors; )*
535                        #args_tokens
536
537                        struct Row {
538                            #(#row_members),*
539                        };
540                        sqlx::__query_with_result(#q, query_args).map(|row|
541                            Row{
542                                #(#row_construct),*
543                            }
544                        )
545                    }}
546                }
547                None => quote! { {
548                    use ::sqlx::Arguments as _;
549                    #(#errors; )*
550                    #args_tokens
551                    sqlx::__query_with_result(#q, query_args)
552                }
553                },
554            };
555            s.into()
556        }
557        sql_type::StatementType::Update {
558            arguments,
559            returning,
560        } => {
561            let (args_tokens, q) = quote_args(
562                &mut errors,
563                &query.query,
564                query.last_span,
565                &query.args,
566                arguments,
567                dialect,
568            );
569
570            let s = match returning.as_ref() {
571                Some(returning) => {
572                    let (row_members, row_construct) = construct_row(returning);
573                    quote! { {
574                        use ::sqlx::Arguments as _;
575                        let _ = std::include_bytes!(#sp);
576                        #(#errors; )*
577                        #args_tokens
578
579                        struct Row {
580                            #(#row_members),*
581                        };
582                        sqlx::__query_with_result(#q, query_args).map(|row|
583                            Row{
584                                #(#row_construct),*
585                            }
586                        )
587                    }}
588                }
589                None => quote! { {
590                    use ::sqlx::Arguments as _;
591                    #(#errors; )*
592                    #args_tokens
593                    sqlx::__query_with_result(#q, query_args)
594                }
595                },
596            };
597            s.into()
598        }
599        sql_type::StatementType::Replace {
600            arguments,
601            returning,
602        } => {
603            let (args_tokens, q) = quote_args(
604                &mut errors,
605                &query.query,
606                query.last_span,
607                &query.args,
608                arguments,
609                dialect,
610            );
611            let s = match returning.as_ref() {
612                Some(returning) => {
613                    let (row_members, row_construct) = construct_row(returning);
614                    quote! { {
615                        use ::sqlx::Arguments as _;
616                        let _ = std::include_bytes!(#sp);
617                        #(#errors; )*
618                        #args_tokens
619
620                        struct Row {
621                            #(#row_members),*
622                        };
623                        sqlx::__query_with_result(#q, query_args).map(|row|
624                            Row{
625                                #(#row_construct),*
626                            }
627                        )
628                    }}
629                }
630                None => quote! { {
631                    use ::sqlx::Arguments as _;
632                    #(#errors; )*
633                    #args_tokens
634                    sqlx::__query_with_result(#q, query_args)
635                }
636                },
637            };
638            s.into()
639        }
640        sql_type::StatementType::Invalid => {
641            let s = quote! { {
642                #(#errors; )*;
643                todo!("Invalid")
644            }};
645            s.into()
646        }
647    }
648}
649
650fn construct_row2(columns: &[SelectTypeColumn]) -> Vec<proc_macro2::TokenStream> {
651    let mut row_construct = Vec::new();
652    for (i, c) in columns.iter().enumerate() {
653        let mut t = match c.type_.t {
654            sql_type::Type::U8 => quote! {u8},
655            sql_type::Type::I8 => quote! {i8},
656            sql_type::Type::U16 => quote! {u16},
657            sql_type::Type::I16 => quote! {i16},
658            sql_type::Type::U32 => quote! {u32},
659            sql_type::Type::I32 => quote! {i32},
660            sql_type::Type::U64 => quote! {u64},
661            sql_type::Type::I64 => quote! {i64},
662            sql_type::Type::Base(sql_type::BaseType::Any) => todo!("from_any"),
663            sql_type::Type::Base(sql_type::BaseType::Bool) => quote! {bool},
664            sql_type::Type::Base(sql_type::BaseType::Bytes) => quote! {Vec<u8>},
665            sql_type::Type::Base(sql_type::BaseType::Date) => quote! {chrono::NaiveDate},
666            sql_type::Type::Base(sql_type::BaseType::DateTime) => quote! {chrono::NaiveDateTime},
667            sql_type::Type::Base(sql_type::BaseType::Float) => quote! {f64},
668            sql_type::Type::Base(sql_type::BaseType::Integer) => quote! {i64},
669            sql_type::Type::Base(sql_type::BaseType::String) => quote! {String},
670            sql_type::Type::Base(sql_type::BaseType::Time) => todo!("from_time"),
671            sql_type::Type::Base(sql_type::BaseType::TimeInterval) => todo!("from_time_interval"),
672            sql_type::Type::Base(sql_type::BaseType::TimeStamp) => {
673                quote! {sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>}
674            }
675            sql_type::Type::Null => todo!("from_null"),
676            sql_type::Type::Invalid => quote! {i64},
677            sql_type::Type::Enum(_) => quote! {String},
678            sql_type::Type::Set(_) => quote! {String},
679            sql_type::Type::Args(_, _) => todo!("from_args"),
680            sql_type::Type::F32 => quote! {f32},
681            sql_type::Type::F64 => quote! {f64},
682            sql_type::Type::JSON => quote! {String},
683        };
684        let name = match &c.name {
685            Some(v) => v,
686            None => continue,
687        };
688
689        let ident = String::from("r#") + name.value;
690        let ident: Ident = if let Ok(ident) = syn::parse_str(&ident) {
691            ident
692        } else {
693            // TODO error
694            //errors.push(syn::Error::new(span, String::from_utf8(out).unwrap()).to_compile_error().into());
695            continue;
696        };
697
698        if !c.type_.not_null {
699            t = quote! {Option<#t>};
700        }
701        row_construct.push(quote! {
702            #ident: sqlx_type::arg_out::<#t, _, #i>(sqlx::Row::get(&row, #i))
703        });
704    }
705    row_construct
706}
707
708struct QueryAs {
709    as_: Ident,
710    query: String,
711    query_span: Span,
712    args: Vec<Expr>,
713    last_span: Span,
714}
715
716impl Parse for QueryAs {
717    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
718        let as_ = input.parse::<Ident>()?;
719        let _ = input.parse::<syn::token::Comma>()?;
720
721        let query_ = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?;
722        let query: String = query_.iter().map(LitStr::value).collect();
723        let query_span = query_.span();
724
725        let mut last_span = query_span;
726        let mut args = Vec::new();
727        while !input.is_empty() {
728            let _ = input.parse::<syn::token::Comma>()?;
729            if input.is_empty() {
730                break;
731            }
732            let arg = input.parse::<Expr>()?;
733            last_span = arg.span();
734            args.push(arg);
735        }
736        Ok(Self {
737            as_,
738            query,
739            query_span,
740            args,
741            last_span,
742        })
743    }
744}
745
746/// A variant of query! which takes a path to an explicitly defined struct as the output type.
747///
748/// This lets you return the struct from a function or add your own trait implementations.
749#[proc_macro]
750pub fn query_as(input: TokenStream) -> TokenStream {
751    let query_as = syn::parse_macro_input!(input as QueryAs);
752    let (schemas, dialect) = SCHEMAS.deref();
753    let options = TypeOptions::new()
754        .dialect(dialect.clone())
755        .arguments(match &dialect {
756            SQLDialect::MariaDB => SQLArguments::QuestionMark,
757            SQLDialect::Sqlite => SQLArguments::QuestionMark,
758            SQLDialect::PostgreSQL => SQLArguments::Dollar,
759        })
760        .list_hack(true);
761    let mut issues = sql_type::Issues::new(&query_as.query);
762    let stmt = type_statement(schemas, &query_as.query, &mut issues, &options);
763
764    let mut errors = issues_to_errors(issues.into_vec(), &query_as.query, query_as.query_span);
765    match &stmt {
766        sql_type::StatementType::Select { columns, arguments } => {
767            let (args_tokens, q) = quote_args(
768                &mut errors,
769                &query_as.query,
770                query_as.last_span,
771                &query_as.args,
772                arguments,
773                dialect,
774            );
775
776            let row_construct = construct_row2(columns);
777            let row = query_as.as_;
778            let s = quote! { {
779                use ::sqlx::Arguments as _;
780                #(#errors; )*
781                #args_tokens
782                sqlx::__query_with_result(#q, query_args).map(|row|
783                    #row{
784                        #(#row_construct),*
785                    }
786                )
787            }};
788            //println!("TOKENS: {}", s);
789            s.into()
790        }
791        sql_type::StatementType::Delete { .. } => {
792            errors.push(
793                syn::Error::new(query_as.query_span, "DELETE not support in query_as")
794                    .to_compile_error(),
795            );
796            quote! { {
797                #(#errors; )*
798                todo!("delete")
799            }}
800            .into()
801        }
802        sql_type::StatementType::Insert {
803            returning: None, ..
804        } => {
805            errors.push(
806                syn::Error::new(
807                    query_as.query_span,
808                    "INSERT without RETURNING not support in query_as",
809                )
810                .to_compile_error(),
811            );
812            quote! { {
813                #(#errors; )*
814                todo!("insert")
815            }}
816            .into()
817        }
818        sql_type::StatementType::Insert {
819            arguments,
820            returning: Some(returning),
821            ..
822        } => {
823            let (args_tokens, q) = quote_args(
824                &mut errors,
825                &query_as.query,
826                query_as.last_span,
827                &query_as.args,
828                arguments,
829                dialect,
830            );
831
832            let row_construct = construct_row2(returning);
833            let row = query_as.as_;
834            let s = quote! { {
835                use ::sqlx::Arguments as _;
836                #(#errors; )*
837                #args_tokens
838                sqlx::__query_with_result(#q, query_args).map(|row|
839                    #row{
840                        #(#row_construct),*
841                    }
842                )
843            }};
844            s.into()
845        }
846        sql_type::StatementType::Update { .. } => {
847            errors.push(
848                syn::Error::new(query_as.query_span, "UPDATE not support in query_as")
849                    .to_compile_error(),
850            );
851            quote! { {
852                #(#errors; )*
853                todo!("update")
854            }}
855            .into()
856        }
857        sql_type::StatementType::Replace {
858            returning: None, ..
859        } => {
860            errors.push(
861                syn::Error::new(
862                    query_as.query_span,
863                    "REPLACE without RETURNING not support in query_as",
864                )
865                .to_compile_error(),
866            );
867            quote! { {
868                #(#errors; )*
869                todo!("replace")
870            }}
871            .into()
872        }
873        sql_type::StatementType::Replace {
874            arguments,
875            returning: Some(returning),
876            ..
877        } => {
878            let (args_tokens, q) = quote_args(
879                &mut errors,
880                &query_as.query,
881                query_as.last_span,
882                &query_as.args,
883                arguments,
884                dialect,
885            );
886
887            let row_construct = construct_row2(returning);
888            let row = query_as.as_;
889            let s = quote! { {
890                use ::sqlx::Arguments as _;
891                #(#errors; )*
892                #args_tokens
893                sqlx::__query_with_result(#q, query_args).map(|row|
894                    #row{
895                        #(#row_construct),*
896                    }
897                )
898            }};
899            s.into()
900        }
901        sql_type::StatementType::Invalid => quote! { {
902            #(#errors; )*;
903            todo!("invalid")
904        }}
905        .into(),
906    }
907}