vespertide_macro/
lib.rs

// `MigrationOptions` and `MigrationError` now live in the `vespertide-core` crate.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token, parse_macro_input};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_query::{DatabaseBackend, build_plan_queries};
9
10struct MacroInput {
11    pool: Expr,
12    version_table: Option<String>,
13}
14
15impl Parse for MacroInput {
16    fn parse(input: ParseStream) -> syn::Result<Self> {
17        let pool = input.parse()?;
18        let mut version_table = None;
19
20        while !input.is_empty() {
21            input.parse::<Token![,]>()?;
22            if input.is_empty() {
23                break;
24            }
25
26            let key: Ident = input.parse()?;
27            if key == "version_table" {
28                input.parse::<Token![=]>()?;
29                let value: syn::LitStr = input.parse()?;
30                version_table = Some(value.value());
31            } else {
32                return Err(syn::Error::new(
33                    key.span(),
34                    "unsupported option for vespertide_migration!",
35                ));
36            }
37        }
38
39        Ok(MacroInput {
40            pool,
41            version_table,
42        })
43    }
44}
45
46/// Zero-runtime migration entry point.
47#[proc_macro]
48pub fn vespertide_migration(input: TokenStream) -> TokenStream {
49    let input = parse_macro_input!(input as MacroInput);
50    let pool = &input.pool;
51    let version_table = input
52        .version_table
53        .unwrap_or_else(|| "vespertide_version".to_string());
54
55    // Load migration files and build SQL at compile time
56    let migrations = match load_migrations_at_compile_time() {
57        Ok(migrations) => migrations,
58        Err(e) => {
59            return syn::Error::new(
60                proc_macro2::Span::call_site(),
61                format!("Failed to load migrations at compile time: {}", e),
62            )
63            .to_compile_error()
64            .into();
65        }
66    };
67    let models = match load_models_at_compile_time() {
68        Ok(models) => models,
69        Err(e) => {
70            return syn::Error::new(
71                proc_macro2::Span::call_site(),
72                format!("Failed to load models at compile time: {}", e),
73            )
74            .to_compile_error()
75            .into();
76        }
77    };
78
79    // Build SQL for each migration
80    let mut migration_blocks = Vec::new();
81    for migration in &migrations {
82        let version = migration.version;
83        let queries = match build_plan_queries(migration, &models) {
84            Ok(queries) => queries,
85            Err(e) => {
86                return syn::Error::new(
87                    proc_macro2::Span::call_site(),
88                    format!(
89                        "Failed to build queries for migration version {}: {}",
90                        version, e
91                    ),
92                )
93                .to_compile_error()
94                .into();
95            }
96        };
97
98        // Pre-generate SQL for all backends at compile time
99        let sql_statements: Vec<_> = queries
100            .iter()
101            .map(|q| {
102                let pg_sql = q
103                    .postgres
104                    .iter()
105                    .map(|q| q.build(DatabaseBackend::Postgres))
106                    .collect::<Vec<_>>()
107                    .join(";\n");
108                let mysql_sql = q
109                    .mysql
110                    .iter()
111                    .map(|q| q.build(DatabaseBackend::MySql))
112                    .collect::<Vec<_>>()
113                    .join(";\n");
114                let sqlite_sql = q
115                    .sqlite
116                    .iter()
117                    .map(|q| q.build(DatabaseBackend::Sqlite))
118                    .collect::<Vec<_>>()
119                    .join(";\n");
120
121                quote! {
122                    match backend {
123                        sea_orm::DatabaseBackend::Postgres => #pg_sql,
124                        sea_orm::DatabaseBackend::MySql => #mysql_sql,
125                        sea_orm::DatabaseBackend::Sqlite => #sqlite_sql,
126                        _ => #pg_sql, // Fallback to PostgreSQL syntax for unknown backends
127                    }
128                }
129            })
130            .collect();
131
132        // Generate version guard and SQL execution block
133        let block = quote! {
134            if version < #version {
135                // Begin transaction
136                let txn = __pool.begin().await.map_err(|e| {
137                    ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
138                })?;
139
140                // Execute SQL statements
141                #(
142                    {
143                        let sql: &str = #sql_statements;
144                        let stmt = sea_orm::Statement::from_string(backend, sql);
145                        txn.execute_raw(stmt).await.map_err(|e| {
146                            ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
147                        })?;
148                    }
149                )*
150
151                // Insert version record for this migration
152                let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
153                let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
154                let stmt = sea_orm::Statement::from_string(backend, insert_sql);
155                txn.execute_raw(stmt).await.map_err(|e| {
156                    ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
157                })?;
158
159                // Commit transaction
160                txn.commit().await.map_err(|e| {
161                    ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
162                })?;
163            }
164        };
165
166        migration_blocks.push(block);
167    }
168
169    // Emit final generated async block
170    let generated = quote! {
171        async {
172            use sea_orm::{ConnectionTrait, TransactionTrait};
173            let __pool = #pool;
174            let version_table = #version_table;
175            let backend = __pool.get_database_backend();
176
177            // Create version table if it does not exist
178            // Table structure: version (INTEGER PRIMARY KEY), created_at (timestamp)
179            let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
180            let create_table_sql = format!(
181                "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
182                version_table
183            );
184            let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
185            __pool.execute_raw(stmt).await.map_err(|e| {
186                ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
187            })?;
188
189            // Read current maximum version (latest applied migration)
190            let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
191            let stmt = sea_orm::Statement::from_string(backend, select_sql);
192            let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
193                ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
194            })?;
195
196            let mut version = version_result
197                .and_then(|row| row.try_get::<i32>("", "version").ok())
198                .unwrap_or(0) as u32;
199
200            // Execute each migration block
201            #(#migration_blocks)*
202
203            Ok::<(), ::vespertide::MigrationError>(())
204        }
205    };
206
207    generated.into()
208}