1use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{Expr, Ident, Token, parse_macro_input};
7use vespertide_loader::{load_migrations_at_compile_time, load_models_at_compile_time};
8use vespertide_query::{DatabaseBackend, build_plan_queries};
9
10struct MacroInput {
11 pool: Expr,
12 version_table: Option<String>,
13}
14
15impl Parse for MacroInput {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let pool = input.parse()?;
18 let mut version_table = None;
19
20 while !input.is_empty() {
21 input.parse::<Token![,]>()?;
22 if input.is_empty() {
23 break;
24 }
25
26 let key: Ident = input.parse()?;
27 if key == "version_table" {
28 input.parse::<Token![=]>()?;
29 let value: syn::LitStr = input.parse()?;
30 version_table = Some(value.value());
31 } else {
32 return Err(syn::Error::new(
33 key.span(),
34 "unsupported option for vespertide_migration!",
35 ));
36 }
37 }
38
39 Ok(MacroInput {
40 pool,
41 version_table,
42 })
43 }
44}
45
46#[proc_macro]
48pub fn vespertide_migration(input: TokenStream) -> TokenStream {
49 let input = parse_macro_input!(input as MacroInput);
50 let pool = &input.pool;
51 let version_table = input
52 .version_table
53 .unwrap_or_else(|| "vespertide_version".to_string());
54
55 let migrations = match load_migrations_at_compile_time() {
57 Ok(migrations) => migrations,
58 Err(e) => {
59 return syn::Error::new(
60 proc_macro2::Span::call_site(),
61 format!("Failed to load migrations at compile time: {}", e),
62 )
63 .to_compile_error()
64 .into();
65 }
66 };
67 let models = match load_models_at_compile_time() {
68 Ok(models) => models,
69 Err(e) => {
70 return syn::Error::new(
71 proc_macro2::Span::call_site(),
72 format!("Failed to load models at compile time: {}", e),
73 )
74 .to_compile_error()
75 .into();
76 }
77 };
78
79 let mut migration_blocks = Vec::new();
81 for migration in &migrations {
82 let version = migration.version;
83 let queries = match build_plan_queries(migration, &models) {
84 Ok(queries) => queries,
85 Err(e) => {
86 return syn::Error::new(
87 proc_macro2::Span::call_site(),
88 format!(
89 "Failed to build queries for migration version {}: {}",
90 version, e
91 ),
92 )
93 .to_compile_error()
94 .into();
95 }
96 };
97
98 let sql_statements: Vec<_> = queries
100 .iter()
101 .map(|q| {
102 let pg_sql = q
103 .postgres
104 .iter()
105 .map(|q| q.build(DatabaseBackend::Postgres))
106 .collect::<Vec<_>>()
107 .join(";\n");
108 let mysql_sql = q
109 .mysql
110 .iter()
111 .map(|q| q.build(DatabaseBackend::MySql))
112 .collect::<Vec<_>>()
113 .join(";\n");
114 let sqlite_sql = q
115 .sqlite
116 .iter()
117 .map(|q| q.build(DatabaseBackend::Sqlite))
118 .collect::<Vec<_>>()
119 .join(";\n");
120
121 quote! {
122 match backend {
123 sea_orm::DatabaseBackend::Postgres => #pg_sql,
124 sea_orm::DatabaseBackend::MySql => #mysql_sql,
125 sea_orm::DatabaseBackend::Sqlite => #sqlite_sql,
126 _ => #pg_sql, }
128 }
129 })
130 .collect();
131
132 let block = quote! {
134 if version < #version {
135 let txn = __pool.begin().await.map_err(|e| {
137 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
138 })?;
139
140 #(
142 {
143 let sql: &str = #sql_statements;
144 let stmt = sea_orm::Statement::from_string(backend, sql);
145 txn.execute_raw(stmt).await.map_err(|e| {
146 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL '{}': {}", sql, e))
147 })?;
148 }
149 )*
150
151 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
153 let insert_sql = format!("INSERT INTO {q}{}{q} (version) VALUES ({})", version_table, #version);
154 let stmt = sea_orm::Statement::from_string(backend, insert_sql);
155 txn.execute_raw(stmt).await.map_err(|e| {
156 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
157 })?;
158
159 txn.commit().await.map_err(|e| {
161 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
162 })?;
163 }
164 };
165
166 migration_blocks.push(block);
167 }
168
169 let generated = quote! {
171 async {
172 use sea_orm::{ConnectionTrait, TransactionTrait};
173 let __pool = #pool;
174 let version_table = #version_table;
175 let backend = __pool.get_database_backend();
176
177 let q = if matches!(backend, sea_orm::DatabaseBackend::MySql) { '`' } else { '"' };
180 let create_table_sql = format!(
181 "CREATE TABLE IF NOT EXISTS {q}{}{q} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
182 version_table
183 );
184 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
185 __pool.execute_raw(stmt).await.map_err(|e| {
186 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
187 })?;
188
189 let select_sql = format!("SELECT MAX(version) as version FROM {q}{}{q}", version_table);
191 let stmt = sea_orm::Statement::from_string(backend, select_sql);
192 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
193 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
194 })?;
195
196 let mut version = version_result
197 .and_then(|row| row.try_get::<i32>("", "version").ok())
198 .unwrap_or(0) as u32;
199
200 #(#migration_blocks)*
202
203 Ok::<(), ::vespertide::MigrationError>(())
204 }
205 };
206
207 generated.into()
208}