1use proc_macro::TokenStream;
4use quote::quote;
5use std::env;
6use std::fs;
7use std::path::PathBuf;
8use syn::parse::{Parse, ParseStream};
9use syn::{Expr, Ident, Token, parse_macro_input};
10use vespertide_config::VespertideConfig;
11use vespertide_core::MigrationPlan;
12use vespertide_query::build_plan_queries;
13
14struct MacroInput {
15 pool: Expr,
16 version_table: Option<String>,
17}
18
19impl Parse for MacroInput {
20 fn parse(input: ParseStream) -> syn::Result<Self> {
21 let pool = input.parse()?;
22 let mut version_table = None;
23
24 while !input.is_empty() {
25 input.parse::<Token![,]>()?;
26 if input.is_empty() {
27 break;
28 }
29
30 let key: Ident = input.parse()?;
31 if key == "version_table" {
32 input.parse::<Token![=]>()?;
33 let value: syn::LitStr = input.parse()?;
34 version_table = Some(value.value());
35 } else {
36 return Err(syn::Error::new(
37 key.span(),
38 "unsupported option for vespertide_migration!",
39 ));
40 }
41 }
42
43 Ok(MacroInput {
44 pool,
45 version_table,
46 })
47 }
48}
49
50#[proc_macro]
52pub fn vespertide_migration(input: TokenStream) -> TokenStream {
53 let input = parse_macro_input!(input as MacroInput);
54 let pool = &input.pool;
55 let version_table = input
56 .version_table
57 .unwrap_or_else(|| "vespertide_version".to_string());
58
59 let migrations = match load_migrations_at_compile_time() {
61 Ok(migrations) => migrations,
62 Err(e) => {
63 return syn::Error::new(
64 proc_macro2::Span::call_site(),
65 format!("Failed to load migrations at compile time: {}", e),
66 )
67 .to_compile_error()
68 .into();
69 }
70 };
71
72 let mut migration_blocks = Vec::new();
74 for migration in &migrations {
75 let version = migration.version;
76 let queries = match build_plan_queries(migration) {
77 Ok(queries) => queries,
78 Err(e) => {
79 return syn::Error::new(
80 proc_macro2::Span::call_site(),
81 format!(
82 "Failed to build queries for migration version {}: {}",
83 version, e
84 ),
85 )
86 .to_compile_error()
87 .into();
88 }
89 };
90
91 let sql_statements: Vec<_> = queries
93 .iter()
94 .map(|q| {
95 let sql = &q.sql;
96 let binds = &q.binds;
97 let value_tokens = binds.iter().map(|b| {
98 quote! { sea_orm::Value::String(Some(#b.to_string())) }
99 });
100 quote! { (#sql, vec![#(#value_tokens),*]) }
101 })
102 .collect();
103
104 let block = quote! {
106 if version < #version {
107 let txn = __pool.begin().await.map_err(|e| {
109 ::vespertide::MigrationError::DatabaseError(format!("Failed to begin transaction: {}", e))
110 })?;
111
112 #(
114 {
115 let (sql, values) = #sql_statements;
116 let stmt = sea_orm::Statement::from_sql_and_values(backend, sql, values);
117 txn.execute_raw(stmt).await.map_err(|e| {
118 ::vespertide::MigrationError::DatabaseError(format!("Failed to execute SQL: {}", e))
119 })?;
120 }
121 )*
122
123 let stmt = sea_orm::Statement::from_sql_and_values(
125 backend,
126 &format!("INSERT INTO {} (version) VALUES (?)", version_table),
127 vec![sea_orm::Value::Int(Some(#version as i32))],
128 );
129 txn.execute_raw(stmt).await.map_err(|e| {
130 ::vespertide::MigrationError::DatabaseError(format!("Failed to insert version: {}", e))
131 })?;
132
133 txn.commit().await.map_err(|e| {
135 ::vespertide::MigrationError::DatabaseError(format!("Failed to commit transaction: {}", e))
136 })?;
137 }
138 };
139
140 migration_blocks.push(block);
141 }
142
143 let generated = quote! {
145 async {
146 use sea_orm::{ConnectionTrait, TransactionTrait};
147 let __pool = #pool;
148 let version_table = #version_table;
149 let backend = __pool.get_database_backend();
150
151 let create_table_sql = format!(
154 "CREATE TABLE IF NOT EXISTS {} (version INTEGER PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)",
155 version_table
156 );
157 let stmt = sea_orm::Statement::from_string(backend, create_table_sql);
158 __pool.execute_raw(stmt).await.map_err(|e| {
159 ::vespertide::MigrationError::DatabaseError(format!("Failed to create version table: {}", e))
160 })?;
161
162 let stmt = sea_orm::Statement::from_string(
164 backend,
165 format!("SELECT MAX(version) as version FROM {}", version_table),
166 );
167 let version_result = __pool.query_one_raw(stmt).await.map_err(|e| {
168 ::vespertide::MigrationError::DatabaseError(format!("Failed to read version: {}", e))
169 })?;
170
171 let mut version = version_result
172 .and_then(|row| row.try_get::<i32>("", "version").ok())
173 .unwrap_or(0) as u32;
174
175 #(#migration_blocks)*
177
178 Ok::<(), ::vespertide::MigrationError>(())
179 }
180 };
181
182 generated.into()
183}
184
185fn load_migrations_at_compile_time() -> Result<Vec<MigrationPlan>, Box<dyn std::error::Error>> {
186 let manifest_dir = env::var("CARGO_MANIFEST_DIR")
188 .map_err(|_| "CARGO_MANIFEST_DIR environment variable not set")?;
189 let project_root = PathBuf::from(manifest_dir);
190
191 let config_path = project_root.join("vespertide.json");
193 let config: VespertideConfig = if config_path.exists() {
194 let content = fs::read_to_string(&config_path)?;
195 serde_json::from_str(&content)?
196 } else {
197 VespertideConfig::default()
199 };
200
201 let migrations_dir = project_root.join(config.migrations_dir());
203 if !migrations_dir.exists() {
204 return Ok(Vec::new());
205 }
206
207 let mut plans = Vec::new();
208 let entries = fs::read_dir(&migrations_dir)?;
209
210 for entry in entries {
211 let entry = entry?;
212 let path = entry.path();
213 if path.is_file() {
214 let ext = path.extension().and_then(|s| s.to_str());
215 if ext == Some("json") || ext == Some("yaml") || ext == Some("yml") {
216 let content = fs::read_to_string(&path)?;
217
218 let plan: MigrationPlan = if ext == Some("json") {
219 serde_json::from_str(&content)?
220 } else {
221 serde_yaml::from_str(&content)?
222 };
223
224 plans.push(plan);
225 }
226 }
227 }
228
229 plans.sort_by_key(|p| p.version);
231 Ok(plans)
232}