vespertide_macro/
lib.rs

1// MigrationOptions and MigrationError are now in vespertide-core
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token, parse_macro_input};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12    pool: Expr,
13    version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let pool = input.parse()?;
19        let mut version_table = None;
20
21        while !input.is_empty() {
22            input.parse::<Token![,]>()?;
23            if input.is_empty() {
24                break;
25            }
26
27            let key: Ident = input.parse()?;
28            if key == "version_table" {
29                input.parse::<Token![=]>()?;
30                let value: syn::LitStr = input.parse()?;
31                version_table = Some(value.value());
32            } else {
33                return Err(syn::Error::new(
34                    key.span(),
35                    "unsupported option for vespertide_migration!",
36                ));
37            }
38        }
39
40        Ok(MacroInput {
41            pool,
42            version_table,
43        })
44    }
45}
46
47/// Zero-runtime migration entry point.
48#[proc_macro]
49pub fn vespertide_migration(input: TokenStream) -> TokenStream {
50    let input = parse_macro_input!(input as MacroInput);
51    let pool = &input.pool;
52    let version_table = input
53        .version_table
54        .unwrap_or_else(|| "vespertide_version".to_string());
55
56    // Load migration files and build SQL at compile time
57    let migrations = match load_migrations_at_compile_time() {
58        Ok(migrations) => migrations,
59        Err(e) => {
60            return syn::Error::new(
61                proc_macro2::Span::call_site(),
62                format!("Failed to load migrations at compile time: {}", e),
63            )
64            .to_compile_error()
65            .into();
66        }
67    };
68    let _models = match load_models_at_compile_time() {
69        Ok(models) => models,
70        Err(e) => {
71            return syn::Error::new(
72                proc_macro2::Span::call_site(),
73                format!("Failed to load models at compile time: {}", e),
74            )
75            .to_compile_error()
76            .into();
77        }
78    };
79
80    // Build SQL for each migration using incremental baseline schema
81    // This is the same approach as cmd_log: start with empty schema and apply each migration
82    let mut baseline_schema = Vec::new();
83    let mut migration_blocks = Vec::new();
84
85    for migration in &migrations {
86        let version = migration.version;
87
88        // Use the current baseline schema (from all previous migrations)
89        let queries = match build_plan_queries(migration, &baseline_schema) {
90            Ok(queries) => queries,
91            Err(e) => {
92                return syn::Error::new(
93                    proc_macro2::Span::call_site(),
94                    format!(
95                        "Failed to build queries for migration version {}: {}",
96                        version, e
97                    ),
98                )
99                .to_compile_error()
100                .into();
101            }
102        };
103
104        // Update baseline schema incrementally by applying each action
105        for action in &migration.actions {
106            let _ = apply_action(&mut baseline_schema, action);
107        }
108
109        // Pre-generate SQL for all backends at compile time
110        // Each query may produce multiple SQL statements, so we flatten them
111        let mut pg_sqls = Vec::new();
112        let mut mysql_sqls = Vec::new();
113        let mut sqlite_sqls = Vec::new();
114
115        for q in &queries {
116            for stmt in &q.postgres {
117                pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
118            }
119            for stmt in &q.mysql {
120                mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
121            }
122            for stmt in &q.sqlite {
123                sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
124            }
125        }
126
127        // Generate version guard and SQL execution block
128        let block = quote! {
129            if version < #version {
130                // Begin transaction
131                let txn = __pool.begin().await.map_err(|e| {
132                    ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
133                })?;
134
135                // Select SQL statements based on backend
136                let sqls: &[&str] = match backend {
137                    sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
138                    sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
139                    sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
140                    _ => &[#(#pg_sqls),*], // Fallback to PostgreSQL syntax for unknown backends
141                };
142
143                // Execute SQL statements
144                for sql in sqls {
145                    if !sql.is_empty() {
146                        let stmt = sea_orm::Statement::from_string(backend, *sql);
147                        txn.execute_raw(stmt).await.map_err(|e| {
148                            ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
149                        })?;
150                    }
151                }
152
153                // Insert version record for this migration
154                let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
155                let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
156                let stmt = sea_orm::Statement::from_string(backend, insert_sql);
157                txn.execute_raw(stmt).await.map_err(|e| {
158                    ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
159                })?;
160
161                // Commit transaction
162                txn.commit().await.map_err(|e| {
163                    ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
164                })?;
165            }
166        };
167
168        migration_blocks.push(block);
169    }
170
171    // Emit final generated async block
172    let generated = quote! {
173        async {
174            use sea_orm::{ConnectionTrait, TransactionTrait};
175            let __pool = #pool;
176            let version_table = #version_table;
177            let backend = __pool.get_database_backend();
178
179            // Create version table if it does not exist
180            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
181            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
182            let create_table_sql = format!(
183                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
184                version_table
185            );
186            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
187            __pool.execute_raw(stmt).await.map_err(|e| {
188                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
189            })?;
190
191            // Read current maximum version (latest applied migration)
192            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
193            let stmt = sea_orm::Statement::from_string(backend, select_sql);
194            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
195                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
196            })?;
197
198            let mut version = version_result
199                .and_then(|row| row.try_get::<i32>("", "version").ok())
200                .unwrap_or(0) as u32;
201
202            // Execute each migration block
203            #(#migration_blocks)*
204
205            Ok::<(), ::vespertide::MigrationError>(())
206        }
207    };
208
209    generated.into()
210}