1use std::{collections::BTreeMap, ops::Not};
2
3use dummy::from_row_impl;
4use heck::{ToSnekCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use syn::{
8 punctuated::Punctuated, Attribute, Ident, ItemEnum, ItemStruct, Meta, Path, Token, Type,
9};
10
11mod dummy;
12mod table;
13
14#[proc_macro_attribute]
147pub fn schema(
148 attr: proc_macro::TokenStream,
149 item: proc_macro::TokenStream,
150) -> proc_macro::TokenStream {
151 assert!(attr.is_empty());
152 let item = syn::parse_macro_input!(item as ItemEnum);
153
154 match generate(item) {
155 Ok(x) => x,
156 Err(e) => e.into_compile_error(),
157 }
158 .into()
159}
160
161#[proc_macro_derive(FromDummy)]
216pub fn from_row(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217 let item = syn::parse_macro_input!(item as ItemStruct);
218 match from_row_impl(item) {
219 Ok(x) => x,
220 Err(e) => e.into_compile_error(),
221 }
222 .into()
223}
224
225#[derive(Clone)]
226struct Table {
227 referer: bool,
228 uniques: Vec<Unique>,
229 prev: Option<Ident>,
230 name: Ident,
231 columns: BTreeMap<usize, Column>,
232}
233
234#[derive(Clone)]
235struct Unique {
236 name: Ident,
237 columns: Vec<Ident>,
238}
239
240#[derive(Clone)]
241struct Column {
242 name: Ident,
243 typ: Type,
244}
245
/// A version range as written in `#[version(..)]`, e.g. `1..`, `..3` or `2..=4`.
#[derive(Clone)]
struct Range {
    /// First version included in the range (0 when the start is omitted).
    start: u32,
    /// Upper bound, or `None` for an open-ended range.
    end: Option<RangeEnd>,
}

/// Upper bound of a [`Range`].
#[derive(Clone)]
struct RangeEnd {
    /// `true` when written as `..=num`, `false` for `..num`.
    inclusive: bool,
    /// The bound as written.
    num: u32,
}

impl RangeEnd {
    /// The first version that is *not* covered by this bound.
    pub fn end_exclusive(&self) -> u32 {
        if self.inclusive {
            self.num + 1
        } else {
            self.num
        }
    }
}

impl Range {
    /// Whether version `idx` lies inside this range.
    pub fn includes(&self, idx: u32) -> bool {
        // `&&` short-circuits, so the upper bound is only evaluated for
        // indices at or above `start`.
        idx >= self.start
            && self
                .end
                .as_ref()
                .map_or(true, |end| idx < end.end_exclusive())
    }
}
278
279impl syn::parse::Parse for Range {
280 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
281 let start: Option<syn::LitInt> = input.parse()?;
282 let _: Token![..] = input.parse()?;
283 let end: Option<RangeEnd> = input.is_empty().not().then(|| input.parse()).transpose()?;
284
285 let res = Range {
286 start: start
287 .map(|x| x.base10_parse())
288 .transpose()?
289 .unwrap_or_default(),
290 end,
291 };
292 Ok(res)
293 }
294}
295
296impl syn::parse::Parse for RangeEnd {
297 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
298 let equals: Option<Token![=]> = input.parse()?;
299 let end: syn::LitInt = input.parse()?;
300
301 let res = RangeEnd {
302 inclusive: equals.is_some(),
303 num: end.base10_parse()?,
304 };
305 Ok(res)
306 }
307}
308
309fn parse_version(attrs: &[Attribute]) -> syn::Result<Range> {
310 let mut version = None;
311 for attr in attrs {
312 if attr.path().is_ident("version") {
313 if version.is_some() {
314 return Err(syn::Error::new_spanned(
315 attr,
316 "There should be only one version attribute.",
317 ));
318 }
319 version = Some(attr.parse_args()?);
320 } else {
321 return Err(syn::Error::new_spanned(attr, "unexpected attribute"));
322 }
323 }
324 Ok(version.unwrap_or(Range {
325 start: 0,
326 end: None,
327 }))
328}
329
330fn make_generic(name: &Ident) -> Ident {
331 let normalized = name.to_string().to_upper_camel_case();
332 format_ident!("_{normalized}")
333}
334
335fn to_lower(name: &Ident) -> Ident {
336 let normalized = name.to_string().to_snek_case();
337 format_ident!("{normalized}")
338}
339
340fn define_table_migration(
342 prev_columns: Option<&BTreeMap<usize, Column>>,
343 table: &Table,
344) -> Option<TokenStream> {
345 let mut defs = vec![];
346 let mut into_new = vec![];
347 let mut generics = vec![];
348 let mut bounds = vec![];
349 let mut prepare = vec![];
350 let prev_columns_uwrapped = prev_columns.unwrap_or(const { &BTreeMap::new() });
351
352 for (i, col) in &table.columns {
353 let name = &col.name;
354 let prepared_name = format_ident!("prepared_{name}");
355 let name_str = col.name.to_string();
356 let typ = &col.typ;
357 let generic = make_generic(name);
358 if prev_columns_uwrapped.contains_key(i) {
359 into_new.push(quote! {reader.col(#name_str, prev.#name())});
360 } else {
361 defs.push(quote! {pub #name: #generic});
362 bounds.push(quote! {#generic: 't + ::rust_query::Dummy<'t, 'a, _PrevSchema, Out = <#typ as ::rust_query::private::MyTyp>::Out<'a>>});
363 generics.push(generic);
364 prepare.push(
365 quote! {let mut #prepared_name = ::rust_query::Dummy::prepare(self.#name, cacher)},
366 );
367 into_new.push(quote! {reader.col(#name_str, #prepared_name(row))});
368 }
369 }
370
371 if defs.is_empty() && table.columns.len() == prev_columns_uwrapped.len() {
374 return None;
375 }
376
377 let table_name = &table.name;
378 let migration_name = format_ident!("{table_name}Migration");
379 let prev_typ = quote! {#table_name};
380
381 let trait_impl = if prev_columns.is_some() {
382 quote! {
383 impl<'t, 'a #(,#bounds)*> ::rust_query::private::TableMigration<'t, 'a> for #migration_name<#(#generics),*> {
384 type From = #prev_typ;
385 type To = super::#table_name;
386
387 fn prepare(
388 self: Box<Self>,
389 prev: ::rust_query::private::Cached<'t, Self::From>,
390 cacher: ::rust_query::private::Cacher<'_, 't, <Self::From as ::rust_query::Table>::Schema>,
391 ) -> Box<
392 dyn FnMut(::rust_query::private::Row<'_, 't, 'a>, ::rust_query::private::Reader<'_, 't, <Self::From as ::rust_query::Table>::Schema>) + 't,
393 >
394 where
395 'a: 't
396 {
397 #(#prepare;)*
398 Box::new(move |row, reader| {
399 let prev = row.get(prev);
400 #(#into_new;)*
401 })
402 }
403 }
404 }
405 } else {
406 quote! {
407 impl<'t, 'a #(,#bounds)*> ::rust_query::private::TableCreation<'t, 'a> for #migration_name<#(#generics),*>{
408 type FromSchema = _PrevSchema;
409 type To = super::#table_name;
410
411 fn prepare(
412 self: Box<Self>,
413 cacher: ::rust_query::private::Cacher<'_, 't, Self::FromSchema>,
414 ) -> Box<
415 dyn FnMut(::rust_query::private::Row<'_, 't, 'a>, ::rust_query::private::Reader<'_, 't, Self::FromSchema>) + 't,
416 >
417 where
418 'a: 't
419 {
420 #(#prepare;)*
421 Box::new(move |row, reader| {
422 #(#into_new;)*
423 })
424 }
425 }
426 }
427 };
428
429 let migration = quote! {
430 pub struct #migration_name<#(#generics),*> {
431 #(#defs,)*
432 }
433
434 #trait_impl
435 };
436 Some(migration)
437}
438
439fn is_unique(path: &Path) -> Option<Ident> {
440 path.get_ident().and_then(|ident| {
441 ident
442 .to_string()
443 .starts_with("unique")
444 .then(|| ident.clone())
445 })
446}
447
448fn generate(item: ItemEnum) -> syn::Result<TokenStream> {
449 let range = parse_version(&item.attrs)?;
450 let schema = &item.ident;
451
452 let mut output = TokenStream::new();
453 let mut prev_tables: BTreeMap<usize, Table> = BTreeMap::new();
454 let mut prev_mod = None;
455 let range_end = range.end.map(|x| x.end_exclusive()).unwrap_or(1);
456 for version in range.start..range_end {
457 let mut new_tables: BTreeMap<usize, Table> = BTreeMap::new();
458
459 let mut mod_output = TokenStream::new();
460 for (i, table) in item.variants.iter().enumerate() {
461 let mut other_attrs = vec![];
462 let mut uniques = vec![];
463 let mut referer = true;
464 for attr in &table.attrs {
465 if let Some(unique) = is_unique(attr.path()) {
466 let idents = attr.parse_args_with(
467 Punctuated::<Ident, Token![,]>::parse_separated_nonempty,
468 )?;
469 uniques.push(Unique {
470 name: unique,
471 columns: idents.into_iter().collect(),
472 })
473 } else if attr.path().is_ident("no_reference") {
474 if version + 1 == range_end {
476 referer = false;
477 }
478 } else {
479 other_attrs.push(attr.clone());
480 }
481 }
482
483 let range = parse_version(&other_attrs)?;
484 if !range.includes(version) {
485 continue;
486 }
487 let mut prev = None;
488 if version != range.start {
490 prev = Some(table.ident.clone());
492 }
493
494 let mut columns = BTreeMap::new();
495 for (i, field) in table.fields.iter().enumerate() {
496 let Some(name) = field.ident.clone() else {
497 return Err(syn::Error::new_spanned(
498 field,
499 "Expected table columns to be named.",
500 ));
501 };
502 if name.to_string().to_lowercase() == "id" {
504 return Err(syn::Error::new_spanned(
505 name,
506 "The `id` column is reserved to be used by rust-query internally.",
507 ));
508 }
509 let mut other_attrs = vec![];
510 let mut unique = None;
511 for attr in &field.attrs {
512 if let Some(unique_name) = is_unique(attr.path()) {
513 let Meta::Path(_) = &attr.meta else {
514 return Err(syn::Error::new_spanned(
515 attr,
516 "Expected no arguments for field specific unique attribute.",
517 ));
518 };
519 unique = Some(Unique {
520 name: unique_name,
521 columns: vec![name.clone()],
522 })
523 } else {
524 other_attrs.push(attr.clone());
525 }
526 }
527 let range = parse_version(&other_attrs)?;
528 if !range.includes(version) {
529 continue;
530 }
531 let col = Column {
532 name,
533 typ: field.ty.clone(),
534 };
535 columns.insert(i, col);
536 uniques.extend(unique);
537 }
538
539 let table = Table {
540 referer,
541 prev,
542 name: table.ident.clone(),
543 columns,
544 uniques,
545 };
546
547 mod_output.extend(table::define_table(&table, schema)?);
548 new_tables.insert(i, table);
549 }
550
551 let mut schema_table_typs = vec![];
552
553 let mut table_defs = vec![];
554 let mut tables = vec![];
555
556 let mut table_migrations = TokenStream::new();
557
558 for (i, table) in &new_tables {
560 let table_name = &table.name;
561
562 let table_lower = to_lower(table_name);
563
564 schema_table_typs.push(quote! {b.table::<#table_name>()});
565
566 if let Some(prev_table) = prev_tables.remove(i) {
567 let Some(migration) = define_table_migration(Some(&prev_table.columns), table)
570 else {
571 continue;
572 };
573 table_migrations.extend(migration);
574
575 table_defs.push(quote! {
576 pub #table_lower: ::rust_query::private::M<'t, #table_name, super::#table_name>
577 });
578 tables.push(quote! {b.migrate_table(self.#table_lower)});
579 } else {
580 let Some(migration) = define_table_migration(None, table) else {
581 return Err(syn::Error::new_spanned(
582 &table.name,
583 "Empty tables are not supported (yet).",
584 ));
585 };
586 table_migrations.extend(migration);
587
588 table_defs.push(quote! {
589 pub #table_lower: ::rust_query::private::C<'t, _PrevSchema, super::#table_name>
590 });
591 tables.push(quote! {b.create_from(self.#table_lower)});
592 }
593 }
594 for prev_table in prev_tables.into_values() {
595 let table_ident = &prev_table.name;
598 tables.push(quote! {b.drop_table::<super::super::#prev_mod::#table_ident>()})
599 }
600
601 let version_i64 = version as i64;
602 mod_output.extend(quote! {
603 pub struct #schema;
604 impl ::rust_query::private::Schema for #schema {
605 const VERSION: i64 = #version_i64;
606
607 fn typs(b: &mut ::rust_query::private::TableTypBuilder<Self>) {
608 #(#schema_table_typs;)*
609 }
610 }
611
612 pub fn assert_hash(expect: ::rust_query::private::Expect) {
613 expect.assert_eq(&::rust_query::private::hash_schema::<#schema>())
614 }
615 });
616
617 let new_mod = format_ident!("v{version}");
618
619 let migrations = prev_mod.map(|prev_mod| {
620 let prelude = prelude(&new_tables, &prev_mod, schema);
621
622 let lifetime = table_defs.is_empty().not().then_some(quote! {'t,});
623 quote! {
624 pub mod update {
625 #prelude
626
627 #table_migrations
628
629 pub struct #schema<#lifetime> {
630 #(#table_defs,)*
631 }
632
633 impl<'t> ::rust_query::private::Migration<'t> for #schema<#lifetime> {
634 type From = _PrevSchema;
635 type To = super::#schema;
636
637 fn tables(self, b: &mut ::rust_query::private::SchemaBuilder<'_, 't>) {
638 #(#tables;)*
639 }
640 }
641 }
642 }
643 });
644
645 output.extend(quote! {
646 mod #new_mod {
647 #mod_output
648
649 #migrations
650 }
651 });
652
653 prev_tables = new_tables;
654 prev_mod = Some(new_mod);
655 }
656
657 Ok(output)
658}
659
660fn prelude(new_tables: &BTreeMap<usize, Table>, prev_mod: &Ident, schema: &Ident) -> TokenStream {
661 let mut prelude = vec![];
662 for table in new_tables.values() {
663 let Some(old_name) = &table.prev else {
664 continue;
665 };
666 let new_name = &table.name;
667 prelude.push(quote! {
668 #old_name as #new_name
669 });
670 }
671 prelude.push(quote! {#schema as _PrevSchema});
672 let mut prelude = quote! {
673 #[allow(unused_imports)]
674 use super::super::#prev_mod::{#(#prelude,)*};
675 };
676 for table in new_tables.values() {
677 if table.prev.is_none() {
678 let new_name = &table.name;
679 prelude.extend(quote! {
680 #[allow(unused_imports)]
681 use ::rust_query::migration::NoTable as #new_name;
682 })
683 }
684 }
685 prelude
686}