1use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12 pool: Expr,
13 version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let pool = input.parse()?;
19 let mut version_table = None;
20
21 while !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 if input.is_empty() {
24 break;
25 }
26
27 let key: Ident = input.parse()?;
28 if key == "version_table" {
29 input.parse::<Token![=]>()?;
30 let value: syn::LitStr = input.parse()?;
31 version_table = Some(value.value());
32 } else {
33 return Err(syn::Error::new(
34 key.span(),
35 "unsupported option for vespertide_migration!",
36 ));
37 }
38 }
39
40 Ok(MacroInput {
41 pool,
42 version_table,
43 })
44 }
45}
46
47pub(crate) fn build_migration_block(
50 migration: &vespertide_core::MigrationPlan,
51 baseline_schema: &mut Vec<vespertide_core::TableDef>,
52) -> Result<proc_macro2::TokenStream, String> {
53 let version = migration.version;
54
55 let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
57 format!(
58 "Failed to build queries for migration version {}: {}",
59 version, e
60 )
61 })?;
62
63 for action in &migration.actions {
65 let _ = apply_action(baseline_schema, action);
66 }
67
68 let mut pg_sqls = Vec::new();
71 let mut mysql_sqls = Vec::new();
72 let mut sqlite_sqls = Vec::new();
73
74 for q in &queries {
75 for stmt in &q.postgres {
76 pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
77 }
78 for stmt in &q.mysql {
79 mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
80 }
81 for stmt in &q.sqlite {
82 sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
83 }
84 }
85
86 let block = quote! {
88 if version < #version {
89 let txn = __pool.begin().await.map_err(|e| {
91 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
92 })?;
93
94 let sqls: &[&str] = match backend {
96 sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
97 sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
98 sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
99 _ => &[#(#pg_sqls),*], };
101
102 for sql in sqls {
104 if !sql.is_empty() {
105 let stmt = sea_orm::Statement::from_string(backend, *sql);
106 txn.execute_raw(stmt).await.map_err(|e| {
107 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
108 })?;
109 }
110 }
111
112 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
114 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
115 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
116 txn.execute_raw(stmt).await.map_err(|e| {
117 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
118 })?;
119
120 txn.commit().await.map_err(|e| {
122 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
123 })?;
124 }
125 };
126
127 Ok(block)
128}
129
130pub(crate) fn generate_migration_code(
132 pool: &Expr,
133 version_table: &str,
134 migration_blocks: Vec<proc_macro2::TokenStream>,
135) -> proc_macro2::TokenStream {
136 quote! {
137 async {
138 use sea_orm::{ConnectionTrait, TransactionTrait};
139 let __pool = #pool;
140 let version_table = #version_table;
141 let backend = __pool.get_database_backend();
142
143 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
146 let create_table_sql = format!(
147 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
148 version_table
149 );
150 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
151 __pool.execute_raw(stmt).await.map_err(|e| {
152 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
153 })?;
154
155 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
157 let stmt = sea_orm::Statement::from_string(backend, select_sql);
158 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
159 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
160 })?;
161
162 let mut version = version_result
163 .and_then(|row| row.try_get::<i32>("", "version").ok())
164 .unwrap_or(0) as u32;
165
166 #(#migration_blocks)*
168
169 Ok::<(), ::vespertide::MigrationError>(())
170 }
171 }
172}
173
174pub(crate) fn vespertide_migration_impl(
176 input: proc_macro2::TokenStream,
177) -> proc_macro2::TokenStream {
178 let input: MacroInput = match syn::parse2(input) {
179 Ok(input) => input,
180 Err(e) => return e.to_compile_error(),
181 };
182 let pool = &input.pool;
183 let version_table = input
184 .version_table
185 .unwrap_or_else(|| "vespertide_version".to_string());
186
187 let migrations = match load_migrations_at_compile_time() {
189 Ok(migrations) => migrations,
190 Err(e) => {
191 return syn::Error::new(
192 proc_macro2::Span::call_site(),
193 format!("Failed to load migrations at compile time: {}", e),
194 )
195 .to_compile_error();
196 }
197 };
198 let _models = match load_models_at_compile_time() {
199 Ok(models) => models,
200 Err(e) => {
201 return syn::Error::new(
202 proc_macro2::Span::call_site(),
203 format!("Failed to load models at compile time: {}", e),
204 )
205 .to_compile_error();
206 }
207 };
208
209 let mut baseline_schema = Vec::new();
211 let mut migration_blocks = Vec::new();
212
213 for migration in &migrations {
214 match build_migration_block(migration, &mut baseline_schema) {
215 Ok(block) => migration_blocks.push(block),
216 Err(e) => {
217 return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
218 }
219 }
220 }
221
222 generate_migration_code(pool, &version_table, migration_blocks)
223}
224
225#[proc_macro]
227pub fn vespertide_migration(input: TokenStream) -> TokenStream {
228 vespertide_migration_impl(input.into()).into()
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
    };

    /// Serializes tests that reach the compile-time loaders or modify the environment.
    ///
    /// `test_vespertide_migration_impl_loading_error` removes `CARGO_MANIFEST_DIR` and
    /// expects migration loading to fail, so the loader depends on that variable. Tests that
    /// expand the macro (and thus read it) must not run concurrently with that removal.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning (the guarded data is `()`).
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Restores `CARGO_MANIFEST_DIR` to its saved value on drop, even if the test panics.
    struct ManifestDirGuard(Option<std::ffi::OsString>);

    impl Drop for ManifestDirGuard {
        fn drop(&mut self) {
            if let Some(val) = self.0.take() {
                // SAFETY: the caller holds ENV_LOCK, so no other test in this module reads or
                // writes the environment concurrently.
                // NOTE(review): the test harness itself may still read env vars on other threads.
                unsafe {
                    std::env::set_var("CARGO_MANIFEST_DIR", val);
                }
            }
        }
    }

    #[test]
    fn test_macro_expansion_with_runtime_macros() {
        let _env = lock_env();
        let dir = tempdir().unwrap();

        let test_file_path = dir.path().join("test_macro.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(
            test_file,
            r#"vespertide_migration!(pool, version_table = "test_versions");"#
        )
        .unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // The expansion reports problems as `compile_error!` tokens rather than panicking,
        // so emulation itself is expected to succeed.
        assert!(result.is_ok(), "runtime macro emulation failed");
    }

    #[test]
    fn test_macro_with_simple_pool() {
        let _env = lock_env();
        let dir = tempdir().unwrap();
        let test_file_path = dir.path().join("test_simple.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        assert!(result.is_ok(), "runtime macro emulation failed");
    }

    #[test]
    fn test_macro_parsing_invalid_option() {
        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(output_str.contains("unsupported option"));
    }

    #[test]
    fn test_macro_parsing_valid_input() {
        let _env = lock_env();
        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
        // Either the expansion succeeds or the loaders report a compile error.
        assert!(
            output_str.contains("async") || output_str.contains("Failed to load"),
            "Unexpected output: {}",
            output_str
        );
    }

    #[test]
    fn test_macro_parsing_with_version_table() {
        let _env = lock_env();
        let input: proc_macro2::TokenStream =
            r#"pool, version_table = "custom_versions""#.parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    #[test]
    fn test_macro_parsing_trailing_comma() {
        let _env = lock_env();
        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    /// Builds a non-nullable integer column with no extra attributes.
    fn test_column(name: &str) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    #[test]
    fn test_build_migration_block_create_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 1u32"));
        assert!(block_str.contains("CREATE TABLE"));

        assert_eq!(baseline.len(), 1);
        assert_eq!(baseline[0].name, "users");
    }

    #[test]
    fn test_build_migration_block_add_column() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);

        let add_column_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "users".into(),
                column: Box::new(ColumnDef {
                    name: "email".into(),
                    r#type: ColumnType::Simple(SimpleColumnType::Text),
                    nullable: true,
                    default: None,
                    comment: None,
                    primary_key: None,
                    unique: None,
                    index: None,
                    foreign_key: None,
                }),
                fill_with: None,
            }],
        };

        let result = build_migration_block(&add_column_migration, &mut baseline);
        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 2u32"));
        assert!(block_str.contains("ALTER TABLE"));
        assert!(block_str.contains("ADD COLUMN"));
    }

    #[test]
    fn test_build_migration_block_multiple_actions() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![
                MigrationAction::CreateTable {
                    table: "users".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
                MigrationAction::CreateTable {
                    table: "posts".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
            ],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        assert_eq!(baseline.len(), 2);
    }

    #[test]
    fn test_generate_migration_code() {
        let pool: Expr = syn::parse_str("db_pool").unwrap();
        let version_table = "test_versions";

        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let block = build_migration_block(&migration, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, version_table, vec![block]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("db_pool"));
        assert!(generated_str.contains("test_versions"));
        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(generated_str.contains("SELECT MAX"));
    }

    #[test]
    fn test_generate_migration_code_empty_migrations() {
        let pool: Expr = syn::parse_str("pool").unwrap();
        let version_table = "vespertide_version";

        let generated = generate_migration_code(&pool, version_table, vec![]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("vespertide_version"));
    }

    #[test]
    fn test_generate_migration_code_multiple_blocks() {
        let pool: Expr = syn::parse_str("connection").unwrap();

        let mut baseline = Vec::new();

        let migration1 = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();

        let migration2 = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "posts".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, "migrations", vec![block1, block2]);
        let generated_str = generated.to_string();

        assert!(generated_str.contains("version < 1u32"));
        assert!(generated_str.contains("version < 2u32"));
    }

    #[test]
    fn test_build_migration_block_generates_all_backends() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "test_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let block_str = result.unwrap().to_string();

        assert!(block_str.contains("DatabaseBackend :: Postgres"));
        assert!(block_str.contains("DatabaseBackend :: MySql"));
        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
    }

    #[test]
    fn test_build_migration_block_with_delete_table() {
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "temp_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);
        assert_eq!(baseline.len(), 1);

        let delete_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::DeleteTable {
                table: "temp_table".into(),
            }],
        };

        let result = build_migration_block(&delete_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();
        assert!(block_str.contains("DROP TABLE"));

        assert_eq!(baseline.len(), 0);
    }

    #[test]
    fn test_build_migration_block_with_index() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![
                    test_column("id"),
                    ColumnDef {
                        name: "email".into(),
                        r#type: ColumnType::Simple(SimpleColumnType::Text),
                        nullable: true,
                        default: None,
                        comment: None,
                        primary_key: None,
                        unique: None,
                        index: Some(StrOrBoolOrArray::Bool(true)),
                        foreign_key: None,
                    },
                ],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let table = &baseline[0];
        let normalized = table.clone().normalize();
        assert!(normalized.is_ok());
    }

    #[test]
    fn test_build_migration_block_error_nonexistent_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "nonexistent_table".into(),
                column: Box::new(test_column("new_col")),
                fill_with: None,
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Failed to build queries for migration version 1"));
    }

    #[test]
    fn test_vespertide_migration_impl_loading_error() {
        // Declared in this order so the restore guard drops (restoring the variable)
        // before the lock is released.
        let _env = lock_env();
        let _restore = ManifestDirGuard(std::env::var_os("CARGO_MANIFEST_DIR"));

        // SAFETY: ENV_LOCK is held, so no other test in this module reads or writes the
        // environment concurrently.
        // NOTE(review): the test harness itself may still read env vars on other threads.
        unsafe {
            std::env::remove_var("CARGO_MANIFEST_DIR");
        }

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        assert!(
            output_str.contains("Failed to load migrations at compile time"),
            "Expected loading error, got: {}",
            output_str
        );
    }
}