1use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token, parse_macro_input};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_planner::apply_action;
9use vespertide_query::{DatabaseBackend, build_plan_queries};
10
11struct MacroInput {
12 pool: Expr,
13 version_table: Option<String>,
14}
15
16impl Parse for MacroInput {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let pool = input.parse()?;
19 let mut version_table = None;
20
21 while !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 if input.is_empty() {
24 break;
25 }
26
27 let key: Ident = input.parse()?;
28 if key == "version_table" {
29 input.parse::<Token![=]>()?;
30 let value: syn::LitStr = input.parse()?;
31 version_table = Some(value.value());
32 } else {
33 return Err(syn::Error::new(
34 key.span(),
35 "unsupported option for vespertide_migration!",
36 ));
37 }
38 }
39
40 Ok(MacroInput {
41 pool,
42 version_table,
43 })
44 }
45}
46
47#[proc_macro]
49pub fn vespertide_migration(input: TokenStream) -> TokenStream {
50 let input = parse_macro_input!(input as MacroInput);
51 let pool = &input.pool;
52 let version_table = input
53 .version_table
54 .unwrap_or_else(|| "vespertide_version".to_string());
55
56 let migrations = match load_migrations_at_compile_time() {
58 Ok(migrations) => migrations,
59 Err(e) => {
60 return syn::Error::new(
61 proc_macro2::Span::call_site(),
62 format!("Failed to load migrations at compile time: {}", e),
63 )
64 .to_compile_error()
65 .into();
66 }
67 };
68 let _models = match load_models_at_compile_time() {
69 Ok(models) => models,
70 Err(e) => {
71 return syn::Error::new(
72 proc_macro2::Span::call_site(),
73 format!("Failed to load models at compile time: {}", e),
74 )
75 .to_compile_error()
76 .into();
77 }
78 };
79
80 let mut baseline_schema = Vec::new();
83 let mut migration_blocks = Vec::new();
84
85 for migration in &migrations {
86 let version = migration.version;
87
88 let queries = match build_plan_queries(migration, &baseline_schema) {
90 Ok(queries) => queries,
91 Err(e) => {
92 return syn::Error::new(
93 proc_macro2::Span::call_site(),
94 format!(
95 "Failed to build queries for migration version {}: {}",
96 version, e
97 ),
98 )
99 .to_compile_error()
100 .into();
101 }
102 };
103
104 for action in &migration.actions {
106 let _ = apply_action(&mut baseline_schema, action);
107 }
108
109 let mut pg_sqls = Vec::new();
112 let mut mysql_sqls = Vec::new();
113 let mut sqlite_sqls = Vec::new();
114
115 for q in &queries {
116 for stmt in &q.postgres {
117 pg_sqls.push(stmt.build(DatabaseBackend::Postgres));
118 }
119 for stmt in &q.mysql {
120 mysql_sqls.push(stmt.build(DatabaseBackend::MySql));
121 }
122 for stmt in &q.sqlite {
123 sqlite_sqls.push(stmt.build(DatabaseBackend::Sqlite));
124 }
125 }
126
127 let block = quote! {
129 if version < #version {
130 let txn = __pool.begin().await.map_err(|e| {
132 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
133 })?;
134
135 let sqls: &[&str] = match backend {
137 sea_orm::DatabaseBackend::Postgres => &[#(#pg_sqls),*],
138 sea_orm::DatabaseBackend::MySql => &[#(#mysql_sqls),*],
139 sea_orm::DatabaseBackend::Sqlite => &[#(#sqlite_sqls),*],
140 _ => &[#(#pg_sqls),*], };
142
143 for sql in sqls {
145 if !sql.is_empty() {
146 let stmt = sea_orm::Statement::from_string(backend, *sql);
147 txn.execute_raw(stmt).await.map_err(|e| {
148 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
149 })?;
150 }
151 }
152
153 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
155 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
156 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
157 txn.execute_raw(stmt).await.map_err(|e| {
158 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
159 })?;
160
161 txn.commit().await.map_err(|e| {
163 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
164 })?;
165 }
166 };
167
168 migration_blocks.push(block);
169 }
170
171 let generated = quote! {
173 async {
174 use sea_orm::{ConnectionTrait, TransactionTrait};
175 let __pool = #pool;
176 let version_table = #version_table;
177 let backend = __pool.get_database_backend();
178
179 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
182 let create_table_sql = format!(
183 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
184 version_table
185 );
186 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
187 __pool.execute_raw(stmt).await.map_err(|e| {
188 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
189 })?;
190
191 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
193 let stmt = sea_orm::Statement::from_string(backend, select_sql);
194 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
195 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
196 })?;
197
198 let mut version = version_result
199 .and_then(|row| row.try_get::<i32>("", "version").ok())
200 .unwrap_or(0) as u32;
201
202 #(#migration_blocks)*
204
205 Ok::<(), ::vespertide::MigrationError>(())
206 }
207 };
208
209 generated.into()
210}