rust_query_macros/
lib.rs

1use std::{collections::BTreeMap, ops::Not};
2
3use dummy::from_row_impl;
4use heck::{ToSnekCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use syn::{
8    punctuated::Punctuated, Attribute, Ident, ItemEnum, ItemStruct, Meta, Path, Token, Type,
9};
10
11mod dummy;
12mod table;
13
14/// Use this macro to define your schema.
15///
16/// ## Supported data types:
17/// - `i64` (sqlite `integer`)
18/// - `f64` (sqlite `real`)
19/// - `String` (sqlite `text`)
20/// - Any table in the same schema (sqlite `integer` with foreign key constraint)
21/// - `Option<T>` where `T` is not an `Option` (sqlite nullable)
22///
23/// Booleans are not supported in schemas yet.
24///
25/// ## Unique constraints
26///
27/// For example:
28/// ```
29/// #[rust_query::migration::schema]
30/// #[version(0..=0)]
31/// enum Schema {
32///     User {
33///         #[unique_email]
34///         email: String,
35///         #[unique_username]
36///         username: String,
37///     }
38/// }
39/// # fn main() {}
40/// ```
41/// This will create a single schema with a single table called `user` and two columns.
42/// The table will also have two unique contraints.
43///
44/// To define a unique constraint on a column, you need to add an attribute to that field.
45/// The attribute needs to start with `unique` and can have any suffix.
46/// Within a table, the different unique constraints must have different suffixes.
47///
48/// Optional types are not allowed in unique constraints.
49///
50/// ## Multiple versions
51/// The macro uses enum syntax, but it generates multiple modules of types.
52///
53/// Note that the schema version range is `0..=0` so there is only a version 0.
54/// The generated code will have a structure like this:
55/// ```rust,ignore
56/// mod v0 {
57///     struct User(..);
58///     // a bunch of other stuff
59/// }
60/// ```
61///
62/// # Adding tables
63/// At some point you might want to add a new table.
64/// ```
65/// #[rust_query::migration::schema]
66/// #[version(0..=1)]
67/// enum Schema {
68///     User {
69///         #[unique_email]
70///         email: String,
71///         #[unique_username]
72///         username: String,
73///     },
74///     #[version(1..)] // <-- note that `Game`` has a version range
75///     Game {
76///         name: String,
77///         size: i64,
78///     }
79/// }
80/// # fn main() {}
81/// ```
82/// We now have two schema versions which generates two modules `v0` and `v1`.
83/// They look something like this:
84/// ```rust,ignore
85/// mod v0 {
86///     struct User(..);
87///     // a bunch of other stuff
88/// }
89/// mod v1 {
90///     struct User(..);
91///     struct Game(..);
92///     // a bunch of other stuff
93/// }
94/// ```
95///
96/// # Changing columns
97/// Changing columns is very similar to adding and removing structs.
98/// ```
99/// use rust_query::migration::{schema, Config, Alter};
100/// use rust_query::{Dummy, LocalClient, Database};
101/// #[schema]
102/// #[version(0..=1)]
103/// enum Schema {
104///     User {
105///         #[unique_email]
106///         email: String,
107///         #[unique_username]
108///         username: String,
109///         #[version(1..)] // <-- here
110///         score: i64,
111///     },
112/// }
113/// // In this case it is required to provide a value for each row that already exists.
114/// // This is done with the `v1::update::UserMigration`:
115/// pub fn migrate(client: &mut LocalClient) -> Database<v1::Schema> {
116///     let m = client.migrator(Config::open_in_memory()) // we use an in memory database for this test
117///         .expect("database version is before supported versions");
118///     let m = m.migrate(v1::update::Schema {
119///         user: Box::new(|user|
120///             Alter::new(v1::update::UserMigration {
121///                 score: user.email().map_dummy(|x| x.len() as i64) // use the email length as the new score
122///             })
123///         ),
124///     });
125///     m.finish().expect("database version is after supported versions")
126/// }
127/// # fn main() {}
128/// ```
129/// The `migrate` function first creates an empty database if it does not exists.
130/// Then it migrates the database if necessary, where it initializes every user score to the length of their email.
131///
132/// # Other features
133/// You can delete columns and tables by specifying the version range end.
134/// ```rust,ignore
135/// #[version(..3)]
136/// ```
137/// You can make a multi column unique constraint by specifying it before the table.
138/// ```rust,ignore
139/// #[unique(user, game)]
140/// UserGameStats {
141///     user: User,
142///     game: Game,
143///     score: i64,
144/// }
145/// ```
146#[proc_macro_attribute]
147pub fn schema(
148    attr: proc_macro::TokenStream,
149    item: proc_macro::TokenStream,
150) -> proc_macro::TokenStream {
151    assert!(attr.is_empty());
152    let item = syn::parse_macro_input!(item as ItemEnum);
153
154    match generate(item) {
155        Ok(x) => x,
156        Err(e) => e.into_compile_error(),
157    }
158    .into()
159}
160
161/// Derive [FromDummy] to create a new `*Dummy` struct.
162///
163/// This `*Dummy` struct can then be used with [Query::into_vec] or [Transaction::query_one].
164/// Usage can also be nested.
165///
166/// Note that the result of [Query::into_vec] is sorted. When a `*Dummy` struct is used for
167/// the output, the sorting order depends on the order of the fields in the struct definition.
168///
169/// Example:
170/// ```
171/// #[rust_query::migration::schema]
172/// pub enum Schema {
173///     Thing {
174///         details: Details,
175///         beta: f64,
176///         seconds: i64,
177///     },
178///     Details {
179///         name: String
180///     },
181/// }
182/// use v0::*;
183/// use rust_query::{Table, FromDummy, Transaction};
184///
185/// #[derive(FromDummy)]
186/// struct MyData {
187///     seconds: i64,
188///     is_it_real: bool,
189///     name: String,
190///     other: OtherData
191/// }
192///
193/// #[derive(FromDummy)]
194/// struct OtherData {
195///     alpha: f64,
196///     beta: f64,
197/// }
198///
199/// pub fn do_query(db: &Transaction<Schema>) -> Vec<MyData> {
200///     db.query(|rows| {
201///         let thing = Thing::join(rows);
202///
203///         rows.into_vec(MyDataDummy {
204///             seconds: thing.seconds(),
205///             is_it_real: true,
206///             name: thing.details().name(),
207///             other: OtherDataDummy {
208///                 alpha: 0.5,
209///                 beta: thing.beta(),
210///             },
211///         })
212///     })
213/// }
214/// ```
215#[proc_macro_derive(FromDummy)]
216pub fn from_row(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217    let item = syn::parse_macro_input!(item as ItemStruct);
218    match from_row_impl(item) {
219        Ok(x) => x,
220        Err(e) => e.into_compile_error(),
221    }
222    .into()
223}
224
225#[derive(Clone)]
226struct Table {
227    referer: bool,
228    uniques: Vec<Unique>,
229    prev: Option<Ident>,
230    name: Ident,
231    columns: BTreeMap<usize, Column>,
232}
233
234#[derive(Clone)]
235struct Unique {
236    name: Ident,
237    columns: Vec<Ident>,
238}
239
240#[derive(Clone)]
241struct Column {
242    name: Ident,
243    typ: Type,
244}
245
/// A version interval as written in `#[version(..)]`, e.g. `1..`, `..3` or `0..=2`.
#[derive(Clone)]
struct Range {
    /// First version included in the range.
    start: u32,
    /// Upper bound of the range, or `None` when it is open-ended.
    end: Option<RangeEnd>,
}

/// The upper bound of a [`Range`].
#[derive(Clone)]
struct RangeEnd {
    /// Whether the bound was written with `..=` (and thus includes `num`).
    inclusive: bool,
    /// The bound as written in the attribute.
    num: u32,
}

impl RangeEnd {
    /// Returns the first version that is *not* covered by this bound.
    pub fn end_exclusive(&self) -> u32 {
        if self.inclusive {
            self.num + 1
        } else {
            self.num
        }
    }
}

impl Range {
    /// Returns whether version `idx` falls inside this range.
    pub fn includes(&self, idx: u32) -> bool {
        match &self.end {
            _ if idx < self.start => false,
            Some(bound) => idx < bound.end_exclusive(),
            None => true,
        }
    }
}
278
279impl syn::parse::Parse for Range {
280    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
281        let start: Option<syn::LitInt> = input.parse()?;
282        let _: Token![..] = input.parse()?;
283        let end: Option<RangeEnd> = input.is_empty().not().then(|| input.parse()).transpose()?;
284
285        let res = Range {
286            start: start
287                .map(|x| x.base10_parse())
288                .transpose()?
289                .unwrap_or_default(),
290            end,
291        };
292        Ok(res)
293    }
294}
295
296impl syn::parse::Parse for RangeEnd {
297    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
298        let equals: Option<Token![=]> = input.parse()?;
299        let end: syn::LitInt = input.parse()?;
300
301        let res = RangeEnd {
302            inclusive: equals.is_some(),
303            num: end.base10_parse()?,
304        };
305        Ok(res)
306    }
307}
308
309fn parse_version(attrs: &[Attribute]) -> syn::Result<Range> {
310    let mut version = None;
311    for attr in attrs {
312        if attr.path().is_ident("version") {
313            if version.is_some() {
314                return Err(syn::Error::new_spanned(
315                    attr,
316                    "There should be only one version attribute.",
317                ));
318            }
319            version = Some(attr.parse_args()?);
320        } else {
321            return Err(syn::Error::new_spanned(attr, "unexpected attribute"));
322        }
323    }
324    Ok(version.unwrap_or(Range {
325        start: 0,
326        end: None,
327    }))
328}
329
330fn make_generic(name: &Ident) -> Ident {
331    let normalized = name.to_string().to_upper_camel_case();
332    format_ident!("_{normalized}")
333}
334
335fn to_lower(name: &Ident) -> Ident {
336    let normalized = name.to_string().to_snek_case();
337    format_ident!("{normalized}")
338}
339
/// Generates the `{Table}Migration` struct and its trait impl for `table`.
///
/// `prev_columns` holds the table's columns in the previous schema version, or `None` when
/// the table did not exist there. Only the column keys (field positions) of the previous
/// version are consulted.
///
/// - Columns whose key also appears in `prev_columns` are copied from the previous row
///   (`prev.<column>()`).
/// - Every other column becomes a public generic field `_{Column}` of the migration struct.
///   Its value must implement `::rust_query::Dummy` with the column type's `Out`.
///
/// With `Some(prev_columns)` the struct implements `::rust_query::private::TableMigration`;
/// with `None` it implements `::rust_query::private::TableCreation` from `_PrevSchema`.
///
/// Returns `None` when no column was added and none was removed, i.e. there is nothing to
/// supply or migrate. For a new table (`prev_columns == None`) this only happens when the
/// table has no columns at all.
fn define_table_migration(
    prev_columns: Option<&BTreeMap<usize, Column>>,
    table: &Table,
) -> Option<TokenStream> {
    let mut defs = vec![];
    let mut into_new = vec![];
    let mut generics = vec![];
    let mut bounds = vec![];
    let mut prepare = vec![];
    // A missing previous table is treated as one without columns, so every column is "new".
    let prev_columns_uwrapped = prev_columns.unwrap_or(const { &BTreeMap::new() });

    for (i, col) in &table.columns {
        let name = &col.name;
        let prepared_name = format_ident!("prepared_{name}");
        let name_str = col.name.to_string();
        let typ = &col.typ;
        let generic = make_generic(name);
        if prev_columns_uwrapped.contains_key(i) {
            // Existing column: carry the value over from the previous version of the row.
            into_new.push(quote! {reader.col(#name_str, prev.#name())});
        } else {
            // New column: the caller supplies a `Dummy` through a generic field, which is
            // prepared once and then evaluated for every row.
            defs.push(quote! {pub #name: #generic});
            bounds.push(quote! {#generic: 't + ::rust_query::Dummy<'t, 'a, _PrevSchema, Out = <#typ as ::rust_query::private::MyTyp>::Out<'a>>});
            generics.push(generic);
            prepare.push(
                quote! {let mut #prepared_name = ::rust_query::Dummy::prepare(self.#name, cacher)},
            );
            into_new.push(quote! {reader.col(#name_str, #prepared_name(row))});
        }
    }

    // check that nothing was added or removed
    // (no new fields, and every current column existed before, so equal counts mean the
    // column set is unchanged)
    // we don't need input if only stuff was removed, but it still needs migrating
    if defs.is_empty() && table.columns.len() == prev_columns_uwrapped.len() {
        return None;
    }

    let table_name = &table.name;
    let migration_name = format_ident!("{table_name}Migration");
    // Inside the generated `update` module this name refers to the previous version's table,
    // through the `use ... as` aliases emitted by `prelude`.
    let prev_typ = quote! {#table_name};

    let trait_impl = if prev_columns.is_some() {
        quote! {
            impl<'t, 'a #(,#bounds)*> ::rust_query::private::TableMigration<'t, 'a> for #migration_name<#(#generics),*> {
                type From = #prev_typ;
                type To = super::#table_name;

                fn prepare(
                    self: Box<Self>,
                    prev: ::rust_query::private::Cached<'t, Self::From>,
                    cacher: ::rust_query::private::Cacher<'_, 't, <Self::From as ::rust_query::Table>::Schema>,
                ) -> Box<
                    dyn FnMut(::rust_query::private::Row<'_, 't, 'a>, ::rust_query::private::Reader<'_, 't, <Self::From as ::rust_query::Table>::Schema>) + 't,
                >
                where
                    'a: 't
                {
                    #(#prepare;)*
                    Box::new(move |row, reader| {
                        let prev = row.get(prev);
                        #(#into_new;)*
                    })
                }
            }
        }
    } else {
        // Table is new in this version: it is created from the previous schema without a
        // previous row to read from.
        quote! {
            impl<'t, 'a #(,#bounds)*> ::rust_query::private::TableCreation<'t, 'a> for #migration_name<#(#generics),*>{
                type FromSchema = _PrevSchema;
                type To = super::#table_name;

                fn prepare(
                    self: Box<Self>,
                    cacher: ::rust_query::private::Cacher<'_, 't, Self::FromSchema>,
                ) -> Box<
                    dyn FnMut(::rust_query::private::Row<'_, 't, 'a>, ::rust_query::private::Reader<'_, 't, Self::FromSchema>) + 't,
                >
                where
                    'a: 't
                {
                    #(#prepare;)*
                    Box::new(move |row, reader| {
                        #(#into_new;)*
                    })
                }
            }
        }
    };

    let migration = quote! {
        pub struct #migration_name<#(#generics),*> {
            #(#defs,)*
        }

        #trait_impl
    };
    Some(migration)
}
438
439fn is_unique(path: &Path) -> Option<Ident> {
440    path.get_ident().and_then(|ident| {
441        ident
442            .to_string()
443            .starts_with("unique")
444            .then(|| ident.clone())
445    })
446}
447
/// Expands the `#[schema]` enum into one `v{N}` module per version in the enum's
/// `#[version(..)]` range.
///
/// Each module contains:
/// - the table definitions that exist in that version (via `table::define_table`);
/// - a unit struct named like the enum that implements `::rust_query::private::Schema`;
/// - an `assert_hash` helper.
///
/// Every module except the first generated one also gets an `update` submodule. It describes
/// how to migrate from the previous module: migration structs for changed or new tables, a
/// struct named like the enum listing them, and drops for tables that no longer exist.
///
/// Returns an error for:
/// - malformed or duplicate `#[version]` attributes, or any other unexpected attribute;
/// - unnamed fields;
/// - a column named `id` (compared case-insensitively);
/// - field-level `#[unique*]` attributes with arguments;
/// - newly created tables without columns.
fn generate(item: ItemEnum) -> syn::Result<TokenStream> {
    let range = parse_version(&item.attrs)?;
    let schema = &item.ident;

    let mut output = TokenStream::new();
    // Tables of the previously generated version, keyed by variant position.
    let mut prev_tables: BTreeMap<usize, Table> = BTreeMap::new();
    let mut prev_mod = None;
    // Exclusive end of the schema versions to generate.
    // NOTE(review): an open-ended schema range falls back to 1 here, so e.g. `#[version(2..)]`
    // generates no modules at all; confirm this is intended.
    let range_end = range.end.map(|x| x.end_exclusive()).unwrap_or(1);
    for version in range.start..range_end {
        let mut new_tables: BTreeMap<usize, Table> = BTreeMap::new();

        let mut mod_output = TokenStream::new();
        // Build the `Table` of every variant that exists in `version`; the variant's position
        // `i` identifies the table across versions.
        for (i, table) in item.variants.iter().enumerate() {
            let mut other_attrs = vec![];
            let mut uniques = vec![];
            let mut referer = true;
            for attr in &table.attrs {
                if let Some(unique) = is_unique(attr.path()) {
                    // Table-level `#[unique*(a, b, ..)]`: a (multi-column) unique constraint.
                    let idents = attr.parse_args_with(
                        Punctuated::<Ident, Token![,]>::parse_separated_nonempty,
                    )?;
                    uniques.push(Unique {
                        name: unique,
                        columns: idents.into_iter().collect(),
                    })
                } else if attr.path().is_ident("no_reference") {
                    // `no_reference` only applies to the last version of the schema.
                    if version + 1 == range_end {
                        referer = false;
                    }
                } else {
                    // Anything else must be `#[version(..)]`; `parse_version` rejects the rest.
                    other_attrs.push(attr.clone());
                }
            }

            let range = parse_version(&other_attrs)?;
            if !range.includes(version) {
                continue;
            }
            let mut prev = None;
            // if this is not the first schema version where this table exists
            if version != range.start {
                // the previous name of this table is the current name
                prev = Some(table.ident.clone());
            }

            // Columns are keyed by field position, which identifies a column across versions.
            let mut columns = BTreeMap::new();
            for (i, field) in table.fields.iter().enumerate() {
                let Some(name) = field.ident.clone() else {
                    return Err(syn::Error::new_spanned(
                        field,
                        "Expected table columns to be named.",
                    ));
                };
                // not sure if case matters here
                if name.to_string().to_lowercase() == "id" {
                    return Err(syn::Error::new_spanned(
                        name,
                        "The `id` column is reserved to be used by rust-query internally.",
                    ));
                }
                let mut other_attrs = vec![];
                let mut unique = None;
                for attr in &field.attrs {
                    if let Some(unique_name) = is_unique(attr.path()) {
                        // Field-level `#[unique*]`: a single-column unique constraint without
                        // arguments.
                        let Meta::Path(_) = &attr.meta else {
                            return Err(syn::Error::new_spanned(
                                attr,
                                "Expected no arguments for field specific unique attribute.",
                            ));
                        };
                        unique = Some(Unique {
                            name: unique_name,
                            columns: vec![name.clone()],
                        })
                    } else {
                        other_attrs.push(attr.clone());
                    }
                }
                let range = parse_version(&other_attrs)?;
                if !range.includes(version) {
                    continue;
                }
                let col = Column {
                    name,
                    typ: field.ty.clone(),
                };
                columns.insert(i, col);
                // Only record the field's constraint when the field exists in this version.
                uniques.extend(unique);
            }

            let table = Table {
                referer,
                prev,
                name: table.ident.clone(),
                columns,
                uniques,
            };

            mod_output.extend(table::define_table(&table, schema)?);
            new_tables.insert(i, table);
        }

        let mut schema_table_typs = vec![];

        // Fields of the `update::#schema` struct and the builder calls in its `tables` method.
        let mut table_defs = vec![];
        let mut tables = vec![];

        let mut table_migrations = TokenStream::new();

        // loop over all new table and see what changed
        for (i, table) in &new_tables {
            let table_name = &table.name;

            let table_lower = to_lower(table_name);

            schema_table_typs.push(quote! {b.table::<#table_name>()});

            if let Some(prev_table) = prev_tables.remove(i) {
                // a table already existed, so we need to define a migration

                // No column added or removed: the table gets no field in `update::#schema`
                // and no builder call.
                let Some(migration) = define_table_migration(Some(&prev_table.columns), table)
                else {
                    continue;
                };
                table_migrations.extend(migration);

                table_defs.push(quote! {
                    pub #table_lower: ::rust_query::private::M<'t, #table_name, super::#table_name>
                });
                tables.push(quote! {b.migrate_table(self.#table_lower)});
            } else {
                // The table is new in this version. `define_table_migration` only returns
                // `None` here when the table has no columns.
                let Some(migration) = define_table_migration(None, table) else {
                    return Err(syn::Error::new_spanned(
                        &table.name,
                        "Empty tables are not supported (yet).",
                    ));
                };
                table_migrations.extend(migration);

                table_defs.push(quote! {
                    pub #table_lower: ::rust_query::private::C<'t, _PrevSchema, super::#table_name>
                });
                tables.push(quote! {b.create_from(self.#table_lower)});
            }
        }
        // Whatever is still in `prev_tables` did not match any table of this version.
        for prev_table in prev_tables.into_values() {
            // a table was removed, so we drop it

            let table_ident = &prev_table.name;
            tables.push(quote! {b.drop_table::<super::super::#prev_mod::#table_ident>()})
        }

        let version_i64 = version as i64;
        mod_output.extend(quote! {
            pub struct #schema;
            impl ::rust_query::private::Schema for #schema {
                const VERSION: i64 = #version_i64;

                fn typs(b: &mut ::rust_query::private::TableTypBuilder<Self>) {
                    #(#schema_table_typs;)*
                }
            }

            pub fn assert_hash(expect: ::rust_query::private::Expect) {
                expect.assert_eq(&::rust_query::private::hash_schema::<#schema>())
            }
        });

        let new_mod = format_ident!("v{version}");

        // The `update` module is only emitted when there is a previous version to migrate from.
        let migrations = prev_mod.map(|prev_mod| {
            let prelude = prelude(&new_tables, &prev_mod, schema);

            // Only declare the `'t` parameter when at least one field uses it.
            let lifetime = table_defs.is_empty().not().then_some(quote! {'t,});
            quote! {
                pub mod update {
                    #prelude

                    #table_migrations

                    pub struct #schema<#lifetime> {
                        #(#table_defs,)*
                    }

                    impl<'t> ::rust_query::private::Migration<'t> for #schema<#lifetime> {
                        type From = _PrevSchema;
                        type To = super::#schema;

                        fn tables(self, b: &mut ::rust_query::private::SchemaBuilder<'_, 't>) {
                            #(#tables;)*
                        }
                    }
                }
            }
        });

        output.extend(quote! {
            mod #new_mod {
                #mod_output

                #migrations
            }
        });

        prev_tables = new_tables;
        prev_mod = Some(new_mod);
    }

    Ok(output)
}
659
660fn prelude(new_tables: &BTreeMap<usize, Table>, prev_mod: &Ident, schema: &Ident) -> TokenStream {
661    let mut prelude = vec![];
662    for table in new_tables.values() {
663        let Some(old_name) = &table.prev else {
664            continue;
665        };
666        let new_name = &table.name;
667        prelude.push(quote! {
668            #old_name as #new_name
669        });
670    }
671    prelude.push(quote! {#schema as _PrevSchema});
672    let mut prelude = quote! {
673        #[allow(unused_imports)]
674        use super::super::#prev_mod::{#(#prelude,)*};
675    };
676    for table in new_tables.values() {
677        if table.prev.is_none() {
678            let new_name = &table.name;
679            prelude.extend(quote! {
680                #[allow(unused_imports)]
681                use ::rust_query::migration::NoTable as #new_name;
682            })
683        }
684    }
685    prelude
686}