Skip to main content

vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3use std::env;
4use std::path::PathBuf;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::{Parse, ParseStream};
9use syn::{Expr, Ident, Token};
10use vespertide_loader::{
11    load_config_or_default, load_migrations_at_compile_time, load_models_at_compile_time,
12};
13use vespertide_planner::apply_action;
14use vespertide_query::{DatabaseBackend, build_plan_queries};
15
16struct MacroInput {
17    pool: Expr,
18    version_table: Option<String>,
19}
20
21impl Parse for MacroInput {
22    fn parse(input: ParseStream) -> syn::Result<Self> {
23        let pool = input.parse()?;
24        let mut version_table = None;
25
26        while !input.is_empty() {
27            input.parse::<Token![,]>()?;
28            if input.is_empty() {
29                break;
30            }
31
32            let key: Ident = input.parse()?;
33            if key == "version_table" {
34                input.parse::<Token![=]>()?;
35                let value: syn::LitStr = input.parse()?;
36                version_table = Some(value.value());
37            } else {
38                return Err(syn::Error::new(
39                    key.span(),
40                    "unsupported option for vespertide_migration!",
41                ));
42            }
43        }
44
45        Ok(MacroInput {
46            pool,
47            version_table,
48        })
49    }
50}
51
52/// Build a migration block for a single migration version.
53/// Returns the generated code block and updates the baseline schema.
54pub(crate) fn build_migration_block(
55    migration: &vespertide_core::MigrationPlan,
56    baseline_schema: &mut Vec<vespertide_core::TableDef>,
57) -> Result<proc_macro2::TokenStream, String> {
58    let version = migration.version;
59
60    // Use the current baseline schema (from all previous migrations)
61    let queries = build_plan_queries(migration, baseline_schema).map_err(|e| {
62        format!(
63            "Failed to build queries for migration version {}: {}",
64            version, e
65        )
66    })?;
67
68    // Update baseline schema incrementally by applying each action
69    for action in &migration.actions {
70        let _ = apply_action(baseline_schema, action);
71    }
72
73    // Pre-generate SQL for all backends at compile time
74    // Each query may produce multiple SQL statements, so we flatten them
75    let mut pg_sqls = Vec::new();
76    let mut mysql_sqls = Vec::new();
77    let mut sqlite_sqls = Vec::new();
78
79    for q in &queries {
80        for stmt in &q.postgres {
81            pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
82        }
83        for stmt in &q.mysql {
84            mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
85        }
86        for stmt in &q.sqlite {
87            sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
88        }
89    }
90
91    // Generate version guard and SQL execution block
92    let block = quote! {
93        if version < #version {
94            // Begin transaction
95            let txn = __pool.begin().await.map_err(|e| {
96                ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
97            })?;
98
99            // Select SQL statements based on backend
100            let sqls: &[&str] = match backend {
101                sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
102                sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
103                sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
104                _ => &[#(#pg_sqls),*], // Fallback to PostgreSQL syntax for unknown backends
105            };
106
107            // Execute SQL statements
108            for sql in sqls {
109                if !sql.is_empty() {
110                    let stmt = sea_orm::Statement::from_string(backend, *sql);
111                    txn.execute_raw(stmt).await.map_err(|e| {
112                        ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
113                    })?;
114                }
115            }
116
117            // Insert version record for this migration
118            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
119            let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
120            let stmt = sea_orm::Statement::from_string(backend, insert_sql);
121            txn.execute_raw(stmt).await.map_err(|e| {
122                ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
123            })?;
124
125            // Commit transaction
126            txn.commit().await.map_err(|e| {
127                ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
128            })?;
129        }
130    };
131
132    Ok(block)
133}
134
135/// Generate the final async migration block with all migrations.
136pub(crate) fn generate_migration_code(
137    pool: &Expr,
138    version_table: &str,
139    migration_blocks: Vec<proc_macro2::TokenStream>,
140) -> proc_macro2::TokenStream {
141    quote! {
142        async {
143            use sea_orm::{ConnectionTrait, TransactionTrait};
144            let __pool = #pool;
145            let version_table = #version_table;
146            let backend = __pool.get_database_backend();
147
148            // Create version table if it does not exist
149            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
150            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
151            let create_table_sql = format!(
152                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
153                version_table
154            );
155            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
156            __pool.execute_raw(stmt).await.map_err(|e| {
157                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
158            })?;
159
160            // Read current maximum version (latest applied migration)
161            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
162            let stmt = sea_orm::Statement::from_string(backend, select_sql);
163            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
164                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
165            })?;
166
167            let mut version = version_result
168                .and_then(|row| row.try_get::<i32>("", "version").ok())
169                .unwrap_or(0) as u32;
170
171            // Execute each migration block
172            #(#migration_blocks)*
173
174            Ok::<(), ::vespertide::MigrationError>(())
175        }
176    }
177}
178
179/// Inner implementation that works with proc_macro2::TokenStream for testability.
180pub(crate) fn vespertide_migration_impl(
181    input: proc_macro2::TokenStream,
182) -> proc_macro2::TokenStream {
183    let input: MacroInput = match syn::parse2(input) {
184        Ok(input) => input,
185        Err(e) => return e.to_compile_error(),
186    };
187    let pool = &input.pool;
188
189    // Get project root from CARGO_MANIFEST_DIR (same as load_migrations_at_compile_time)
190    let project_root = match env::var("CARGO_MANIFEST_DIR") {
191        Ok(dir) => Some(PathBuf::from(dir)),
192        Err(_) => None,
193    };
194
195    // Load config to get prefix
196    let config = match load_config_or_default(project_root) {
197        Ok(config) => config,
198        #[cfg(not(tarpaulin_include))]
199        Err(e) => {
200            return syn::Error::new(
201                proc_macro2::Span::call_site(),
202                format!("Failed to load config at compile time: {}", e),
203            )
204            .to_compile_error();
205        }
206    };
207    let prefix = config.prefix();
208
209    // Apply prefix to version_table if not explicitly provided
210    let version_table = input
211        .version_table
212        .map(|vt| config.apply_prefix(&vt))
213        .unwrap_or_else(|| config.apply_prefix("vespertide_version"));
214
215    // Load migration files and build SQL at compile time
216    let migrations = match load_migrations_at_compile_time() {
217        Ok(migrations) => migrations,
218        Err(e) => {
219            return syn::Error::new(
220                proc_macro2::Span::call_site(),
221                format!("Failed to load migrations at compile time: {}", e),
222            )
223            .to_compile_error();
224        }
225    };
226    let _models = match load_models_at_compile_time() {
227        Ok(models) => models,
228        #[cfg(not(tarpaulin_include))]
229        Err(e) => {
230            return syn::Error::new(
231                proc_macro2::Span::call_site(),
232                format!("Failed to load models at compile time: {}", e),
233            )
234            .to_compile_error();
235        }
236    };
237
238    // Apply prefix to migrations and build SQL using incremental baseline schema
239    let mut baseline_schema = Vec::new();
240    let mut migration_blocks = Vec::new();
241
242    #[cfg(not(tarpaulin_include))]
243    for migration in &migrations {
244        // Apply prefix to migration table names
245        let prefixed_migration = migration.clone().with_prefix(prefix);
246        match build_migration_block(&prefixed_migration, &mut baseline_schema) {
247            Ok(block) => migration_blocks.push(block),
248            Err(e) => {
249                return syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error();
250            }
251        }
252    }
253
254    generate_migration_code(pool, &version_table, migration_blocks)
255}
256
257/// Zero-runtime migration entry point.
258#[cfg(not(tarpaulin_include))]
259#[proc_macro]
260pub fn vespertide_migration(input: TokenStream) -> TokenStream {
261    vespertide_migration_impl(input.into()).into()
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::fs::File;
    use std::io::Write;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;
    use vespertide_core::{
        ColumnDef, ColumnType, MigrationAction, MigrationPlan, SimpleColumnType, StrOrBoolOrArray,
    };

    /// Serializes every test that reads or writes `CARGO_MANIFEST_DIR`.
    ///
    /// The test harness runs tests on parallel threads and `vespertide_migration_impl`
    /// reads this variable, so without the lock one test could observe another test's
    /// temporary value (or a removed variable) and fail spuriously.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_LOCK`], ignoring poisoning so one failing test does not cascade.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Set `CARGO_MANIFEST_DIR` to `value`, or remove it when `value` is `None`.
    fn set_manifest_dir(value: Option<&OsStr>) {
        // SAFETY: mutating the environment is unsafe while other threads may read it.
        // Callers hold `ENV_LOCK`, which every test in this module that reads the
        // variable (through `vespertide_migration_impl`) also holds.
        unsafe {
            match value {
                Some(value) => std::env::set_var("CARGO_MANIFEST_DIR", value),
                None => std::env::remove_var("CARGO_MANIFEST_DIR"),
            }
        }
    }

    /// Temporarily overrides `CARGO_MANIFEST_DIR` and restores the previous value on drop,
    /// including when the test panics. Only create it while holding [`lock_env`].
    struct ManifestDirOverride {
        original: Option<OsString>,
    }

    impl ManifestDirOverride {
        /// Point `CARGO_MANIFEST_DIR` at `value`, or remove it when `value` is `None`.
        fn new(value: Option<&Path>) -> Self {
            let original = std::env::var_os("CARGO_MANIFEST_DIR");
            set_manifest_dir(value.map(Path::as_os_str));
            Self { original }
        }
    }

    impl Drop for ManifestDirOverride {
        fn drop(&mut self) {
            set_manifest_dir(self.original.as_deref());
        }
    }

    #[test]
    fn test_macro_expansion_with_runtime_macros() {
        let _env_lock = lock_env();

        // Create a temporary directory with test files
        let dir = tempdir().unwrap();

        // Create a test file that uses the macro
        let test_file_path = dir.path().join("test_macro.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(
            test_file,
            r#"vespertide_migration!(pool, version_table = "test_versions");"#
        )
        .unwrap();

        // Use runtime-macros to emulate macro expansion
        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // Whether the expansion yields migration code or a `compile_error!` depends on the
        // project reachable from CARGO_MANIFEST_DIR; this test only requires that the
        // emulated expansion runs to completion without panicking.
        let _ = result;
    }

    #[test]
    fn test_macro_with_simple_pool() {
        let _env_lock = lock_env();

        let dir = tempdir().unwrap();
        let test_file_path = dir.path().join("test_simple.rs");
        let mut test_file = File::create(&test_file_path).unwrap();
        writeln!(test_file, r#"vespertide_migration!(db_pool);"#).unwrap();

        let file = File::open(&test_file_path).unwrap();
        let result = runtime_macros::emulate_functionlike_macro_expansion(
            file,
            &[("vespertide_migration", vespertide_migration_impl)],
        );

        // Only checks that expansion completes without panicking (see above).
        let _ = result;
    }

    #[test]
    fn test_macro_parsing_invalid_option() {
        let _env_lock = lock_env();

        // Test that invalid options produce a compile error
        let input: proc_macro2::TokenStream = "pool, invalid_option = \"value\"".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        // Should contain an error message about unsupported option
        assert!(output_str.contains("unsupported option"));
    }

    #[test]
    fn test_macro_parsing_valid_input() {
        let _env_lock = lock_env();

        // Test that valid input is parsed correctly
        // The macro will either succeed (if migrations dir exists and is empty)
        // or fail with a migration loading error
        let input: proc_macro2::TokenStream = "my_pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        // Should produce output (either success or migration loading error)
        assert!(!output_str.is_empty());
        // If error, it should mention "Failed to load"
        // If success, it should contain "async"
        assert!(
            output_str.contains("async") || output_str.contains("Failed to load"),
            "Unexpected output: {}",
            output_str
        );
    }

    #[test]
    fn test_macro_parsing_with_version_table() {
        let _env_lock = lock_env();

        let input: proc_macro2::TokenStream =
            r#"pool, version_table = "custom_versions""#.parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    #[test]
    fn test_macro_parsing_trailing_comma() {
        let _env_lock = lock_env();

        let input: proc_macro2::TokenStream = "pool,".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();
        assert!(!output_str.is_empty());
    }

    fn test_column(name: &str) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    #[test]
    fn test_build_migration_block_create_table() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        // Verify the generated block contains expected elements
        assert!(block_str.contains("version < 1u32"));
        assert!(block_str.contains("CREATE TABLE"));

        // Verify baseline schema was updated
        assert_eq!(baseline.len(), 1);
        assert_eq!(baseline[0].name, "users");
    }

    #[test]
    fn test_build_migration_block_add_column() {
        // First create the table
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);

        // Now add a column
        let add_column_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "users".into(),
                column: Box::new(ColumnDef {
                    name: "email".into(),
                    r#type: ColumnType::Simple(SimpleColumnType::Text),
                    nullable: true,
                    default: None,
                    comment: None,
                    primary_key: None,
                    unique: None,
                    index: None,
                    foreign_key: None,
                }),
                fill_with: None,
            }],
        };

        let result = build_migration_block(&add_column_migration, &mut baseline);
        assert!(result.is_ok());
        let block = result.unwrap();
        let block_str = block.to_string();

        assert!(block_str.contains("version < 2u32"));
        assert!(block_str.contains("ALTER TABLE"));
        assert!(block_str.contains("ADD COLUMN"));
    }

    #[test]
    fn test_build_migration_block_multiple_actions() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![
                MigrationAction::CreateTable {
                    table: "users".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
                MigrationAction::CreateTable {
                    table: "posts".into(),
                    columns: vec![test_column("id")],
                    constraints: vec![],
                },
            ],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_ok());
        assert_eq!(baseline.len(), 2);
    }

    #[test]
    fn test_generate_migration_code() {
        let pool: Expr = syn::parse_str("db_pool").unwrap();
        let version_table = "test_versions";

        // Create a simple migration block
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let block = build_migration_block(&migration, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, version_table, vec![block]);
        let generated_str = generated.to_string();

        // Verify the generated code structure
        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("db_pool"));
        assert!(generated_str.contains("test_versions"));
        assert!(generated_str.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(generated_str.contains("SELECT MAX"));
    }

    #[test]
    fn test_generate_migration_code_empty_migrations() {
        let pool: Expr = syn::parse_str("pool").unwrap();
        let version_table = "vespertide_version";

        let generated = generate_migration_code(&pool, version_table, vec![]);
        let generated_str = generated.to_string();

        // Should still generate the wrapper code
        assert!(generated_str.contains("async"));
        assert!(generated_str.contains("vespertide_version"));
    }

    #[test]
    fn test_generate_migration_code_multiple_blocks() {
        let pool: Expr = syn::parse_str("connection").unwrap();

        let mut baseline = Vec::new();

        let migration1 = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block1 = build_migration_block(&migration1, &mut baseline).unwrap();

        let migration2 = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "posts".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };
        let block2 = build_migration_block(&migration2, &mut baseline).unwrap();

        let generated = generate_migration_code(&pool, "migrations", vec![block1, block2]);
        let generated_str = generated.to_string();

        // Both version checks should be present
        assert!(generated_str.contains("version < 1u32"));
        assert!(generated_str.contains("version < 2u32"));
    }

    #[test]
    fn test_build_migration_block_generates_all_backends() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "test_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        let block_str = result.unwrap().to_string();

        // The generated block should have backend matching
        assert!(block_str.contains("DatabaseBackend :: Postgres"));
        assert!(block_str.contains("DatabaseBackend :: MySql"));
        assert!(block_str.contains("DatabaseBackend :: Sqlite"));
    }

    #[test]
    fn test_build_migration_block_with_delete_table() {
        // First create the table
        let create_migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "temp_table".into(),
                columns: vec![test_column("id")],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let _ = build_migration_block(&create_migration, &mut baseline);
        assert_eq!(baseline.len(), 1);

        // Now delete it
        let delete_migration = MigrationPlan {
            version: 2,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::DeleteTable {
                table: "temp_table".into(),
            }],
        };

        let result = build_migration_block(&delete_migration, &mut baseline);
        assert!(result.is_ok());
        let block_str = result.unwrap().to_string();
        assert!(block_str.contains("DROP TABLE"));

        // Baseline should be empty after delete
        assert_eq!(baseline.len(), 0);
    }

    #[test]
    fn test_build_migration_block_with_index() {
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::CreateTable {
                table: "users".into(),
                columns: vec![
                    test_column("id"),
                    ColumnDef {
                        name: "email".into(),
                        r#type: ColumnType::Simple(SimpleColumnType::Text),
                        nullable: true,
                        default: None,
                        comment: None,
                        primary_key: None,
                        unique: None,
                        index: Some(StrOrBoolOrArray::Bool(true)),
                        foreign_key: None,
                    },
                ],
                constraints: vec![],
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);
        assert!(result.is_ok());

        // Table should be normalized with index
        let table = &baseline[0];
        let normalized = table.clone().normalize();
        assert!(normalized.is_ok());
    }

    #[test]
    fn test_build_migration_block_error_nonexistent_table() {
        // Try to add column to a table that doesn't exist - should fail
        let migration = MigrationPlan {
            version: 1,
            comment: None,
            created_at: None,
            actions: vec![MigrationAction::AddColumn {
                table: "nonexistent_table".into(),
                column: Box::new(test_column("new_col")),
                fill_with: None,
            }],
        };

        let mut baseline = Vec::new();
        let result = build_migration_block(&migration, &mut baseline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Failed to build queries for migration version 1"));
    }

    #[test]
    fn test_vespertide_migration_impl_loading_error() {
        let _env_lock = lock_env();
        // Remove CARGO_MANIFEST_DIR to trigger loading error; restored when dropped.
        let _manifest_dir = ManifestDirOverride::new(None);

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        // Should contain error about failed loading
        assert!(
            output_str.contains("Failed to load migrations at compile time"),
            "Expected loading error, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_valid_project() {
        use std::fs;

        let _env_lock = lock_env();

        // Create a temporary directory with a valid vespertide project
        let dir = tempdir().unwrap();
        let project_dir = dir.path();

        // Create vespertide.json config
        let config_content = r#"{
            "modelsDir": "models",
            "migrationsDir": "migrations",
            "tableNamingCase": "snake",
            "columnNamingCase": "snake",
            "modelFormat": "json"
        }"#;
        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();

        // Create empty models and migrations directories
        fs::create_dir_all(project_dir.join("models")).unwrap();
        fs::create_dir_all(project_dir.join("migrations")).unwrap();

        // Point CARGO_MANIFEST_DIR at the temp project; restored when dropped.
        let _manifest_dir = ManifestDirOverride::new(Some(project_dir));

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        // Should produce valid async code since there are no migrations
        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
        assert!(
            output_str.contains("CREATE TABLE IF NOT EXISTS"),
            "Expected version table creation, got: {}",
            output_str
        );
    }

    #[test]
    fn test_vespertide_migration_impl_with_migrations() {
        use std::fs;

        let _env_lock = lock_env();

        // Create a temporary directory with a valid vespertide project and migrations
        let dir = tempdir().unwrap();
        let project_dir = dir.path();

        // Create vespertide.json config
        let config_content = r#"{
            "modelsDir": "models",
            "migrationsDir": "migrations",
            "tableNamingCase": "snake",
            "columnNamingCase": "snake",
            "modelFormat": "json"
        }"#;
        fs::write(project_dir.join("vespertide.json"), config_content).unwrap();

        // Create models and migrations directories
        fs::create_dir_all(project_dir.join("models")).unwrap();
        fs::create_dir_all(project_dir.join("migrations")).unwrap();

        // Create a migration file
        let migration_content = r#"{
            "version": 1,
            "actions": [
                {
                    "type": "create_table",
                    "table": "users",
                    "columns": [
                        {"name": "id", "type": "integer", "nullable": false}
                    ],
                    "constraints": []
                }
            ]
        }"#;
        fs::write(
            project_dir.join("migrations").join("0001_initial.json"),
            migration_content,
        )
        .unwrap();

        // Point CARGO_MANIFEST_DIR at the temp project; restored when dropped.
        let _manifest_dir = ManifestDirOverride::new(Some(project_dir));

        let input: proc_macro2::TokenStream = "pool".parse().unwrap();
        let output = vespertide_migration_impl(input);
        let output_str = output.to_string();

        // Should produce valid async code with migration
        assert!(
            output_str.contains("async"),
            "Expected async block, got: {}",
            output_str
        );
    }
}