1#![doc = include_str!("../README.md")]
2
3use std::future::Future;
4use std::pin::Pin;
5use thiserror::Error;
6use turso::{Connection, Error};
7
8pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
10
11#[derive(Error, Debug)]
12pub enum MigrationError {
13 #[error("Turso error: {0}")]
14 Turso(#[from] Error),
15 #[error("Database has user_version {0}, but it must be between 0 and {1}")]
16 InvalidUserVersion(i32, usize),
17}
18
19#[derive(Clone, Copy)]
22pub enum Migration<'a> {
23 Sql {
24 name: &'a str,
25 sql: &'a str,
26 },
27 Fn {
28 name: &'a str,
29 f: fn(&Connection) -> MigrationFuture,
30 },
31}
32
33impl<'a> Migration<'a> {
34 pub const fn up(name: &'a str, sql: &'a str) -> Self {
36 Self::Sql { name, sql }
37 }
38
39 pub const fn up_fn(name: &'a str, f: fn(&Connection) -> MigrationFuture) -> Self {
43 Self::Fn { name, f }
44 }
45}
46
/// Builds a SQL [`Migration`] from a file embedded at compile time.
///
/// The path literal is used twice: as the migration's name and as the
/// argument to `include_str!`, whose result becomes the migration's SQL.
#[macro_export]
macro_rules! up_file {
    ($file:literal) => {
        $crate::Migration::up($file, include_str!($file))
    };
}
65
/// Builds a function [`Migration`] from a name and a path to a function.
///
/// The expansion defines a local `wrapper` function that calls the given
/// function with the connection and boxes the returned future with
/// `Box::pin`, then passes that wrapper to `Migration::up_fn`.
///
/// The expansion names `turso::Connection` directly, so `turso` must be
/// resolvable at the call site.
#[macro_export]
macro_rules! up_fn {
    ($label:expr, $migrate:path) => {{
        fn wrapper(connection: &turso::Connection) -> $crate::MigrationFuture<'_> {
            Box::pin($migrate(connection))
        }
        $crate::Migration::up_fn($label, wrapper)
    }};
}
90
91pub struct Migrations<'a> {
93 migrations: &'a [Migration<'a>],
94}
95
96impl<'a> Migrations<'a> {
97 pub const fn new(migrations: &'a [Migration<'a>]) -> Self {
98 Self { migrations }
99 }
100
101 pub async fn to_latest(&self, conn: &mut Connection) -> Result<usize, (usize, MigrationError)> {
106 let current_version = match get_user_version(conn).await {
107 Ok(v) => v,
108 Err(e) => return Err((0, e.into())),
109 };
110 let target_version = self.migrations.len() as i32;
111 if current_version == target_version {
112 return Ok(0);
113 }
114 if current_version < 0 || current_version > target_version {
115 return Err((
116 0,
117 MigrationError::InvalidUserVersion(current_version, self.migrations.len()),
118 ));
119 }
120
121 let mut applied_count = 0;
122 for (i, migration) in self
123 .migrations
124 .iter()
125 .enumerate()
126 .skip(current_version as usize)
127 {
128 let version = (i + 1) as i32;
129
130 let tx = match conn.transaction().await {
132 Ok(tx) => tx,
133 Err(e) => return Err((applied_count, e.into())),
134 };
135
136 let result = match migration {
138 Migration::Sql { name: _name, sql } => tx.execute_batch(sql).await.map(|_| ()),
139 Migration::Fn { name: _name, f } => f(&tx).await,
140 };
141
142 #[cfg(feature = "tracing")]
144 let migration_name_log = match migration {
145 Migration::Sql { name, .. } => *name,
146 Migration::Fn { name, .. } => *name,
147 };
148
149 if let Err(e) = result {
150 #[cfg(feature = "tracing")]
151 tracing::error!(error = ?e, migration_name = %migration_name_log, "Migration failed");
152 return Err((applied_count, e.into()));
153 }
154
155 if let Err(e) = set_user_version(&tx, version).await {
157 #[cfg(feature = "tracing")]
158 tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to update user_version");
159 return Err((applied_count, e.into()));
160 }
161
162 if let Err(e) = tx.commit().await {
164 #[cfg(feature = "tracing")]
165 tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to commit transaction");
166 return Err((applied_count, e.into()));
167 }
168
169 #[cfg(feature = "tracing")]
170 tracing::debug!(migration_name = %migration_name_log, "Migration applied successfully");
171
172 applied_count += 1;
173 }
174
175 Ok(applied_count)
176 }
177
178 pub async fn run_all_in_memory(&self) -> Result<Connection, MigrationError> {
181 let db = turso::Builder::new_local(":memory:").build().await?;
182 let mut conn = db.connect()?;
183 if let Err((_, e)) = self.to_latest(&mut conn).await {
184 return Err(e);
185 }
186 Ok(conn)
187 }
188}
189
190async fn get_user_version(conn: &Connection) -> Result<i32, Error> {
191 let version = std::sync::atomic::AtomicI32::new(0);
192 conn.pragma_query("user_version", |row| {
193 let v = row.get::<i32>(0).unwrap();
194 version.store(v, std::sync::atomic::Ordering::SeqCst);
195 Ok(())
196 })
197 .await?;
198 Ok(version.load(std::sync::atomic::Ordering::SeqCst))
199}
200
201async fn set_user_version(conn: &Connection, version: i32) -> Result<(), Error> {
202 conn.pragma_update("user_version", version).await?;
203 Ok(())
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use turso::Builder;

    /// Shared fixture: one SQL migration followed by one function migration.
    const MIGRATIONS: Migrations = Migrations::new(&[
        Migration::up("001", "CREATE TABLE friend(name TEXT NOT NULL);"),
        up_fn!("002", my_complex_migration),
    ]);

    /// Function migration used by the fixture: creates `users` and inserts a row.
    async fn my_complex_migration(conn: &Connection) -> turso::Result<()> {
        let statements = [
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
            "INSERT INTO users (name) VALUES ('Alice')",
        ];
        for sql in statements {
            conn.execute(sql, ()).await?;
        }
        Ok(())
    }

    /// Opens a connection to a fresh in-memory database.
    async fn get_in_memory_conn() -> Connection {
        Builder::new_local(":memory:")
            .build()
            .await
            .unwrap()
            .connect()
            .unwrap()
    }

    /// Both migrations apply to an empty database and their tables are usable.
    #[tokio::test]
    async fn test_migrations_success() {
        let mut conn = get_in_memory_conn().await;

        assert_eq!(MIGRATIONS.to_latest(&mut conn).await.unwrap(), 2);
        assert_eq!(get_user_version(&conn).await.unwrap(), 2);

        // Each migration must have created its table.
        for sql in [
            "INSERT INTO friend (name) VALUES ('test')",
            "INSERT INTO users (name) VALUES ('test')",
        ] {
            conn.execute(sql, ()).await.unwrap();
        }
    }

    /// A second run on an up-to-date database applies nothing.
    #[tokio::test]
    async fn test_idempotency() {
        let mut conn = get_in_memory_conn().await;

        for expected in [2, 0] {
            let applied_count = MIGRATIONS.to_latest(&mut conn).await.unwrap();
            assert_eq!(applied_count, expected);
        }
    }

    /// Starting from `user_version = 1`, only the second migration runs.
    #[tokio::test]
    async fn test_partial_migration() {
        let mut conn = get_in_memory_conn().await;
        conn.pragma_update("user_version", 1).await.unwrap();

        assert_eq!(MIGRATIONS.to_latest(&mut conn).await.unwrap(), 1);
        assert_eq!(get_user_version(&conn).await.unwrap(), 2);
    }

    /// `run_all_in_memory` returns a fully migrated connection.
    #[tokio::test]
    async fn test_validate_helper() {
        let outcome = MIGRATIONS.run_all_in_memory().await;
        assert!(outcome.is_ok());

        let conn = outcome.unwrap();
        assert_eq!(get_user_version(&conn).await.unwrap(), 2);
    }

    /// A negative stored version is rejected before anything is applied.
    #[tokio::test]
    async fn test_negative_user_version() {
        let mut conn = get_in_memory_conn().await;
        conn.pragma_update("user_version", -5).await.unwrap();

        let result = MIGRATIONS.to_latest(&mut conn).await;

        if let Err((0, MigrationError::InvalidUserVersion(found, 2))) = &result {
            assert_eq!(*found, -5);
        } else {
            panic!("Expected InvalidUserVersion error, got {:?}", result);
        }
    }

    /// A stored version beyond the number of migrations is rejected.
    #[tokio::test]
    async fn test_user_version_too_high() {
        let mut conn = get_in_memory_conn().await;
        conn.pragma_update("user_version", 10).await.unwrap();

        let result = MIGRATIONS.to_latest(&mut conn).await;

        if let Err((0, MigrationError::InvalidUserVersion(found, limit))) = &result {
            assert_eq!(*found, 10);
            assert_eq!(*limit, 2);
        } else {
            panic!("Expected InvalidUserVersion error, got {:?}", result);
        }
    }

    /// A failing second step reports one applied migration and leaves the
    /// stored version at 1.
    #[tokio::test]
    async fn test_failing_migration() {
        const BROKEN_MIGRATIONS: Migrations = Migrations::new(&[
            Migration::up("001", "CREATE TABLE ok (id int)"),
            Migration::up("002", "SELECT * FROM non_existent_table"),
        ]);

        let mut conn = get_in_memory_conn().await;

        let result = BROKEN_MIGRATIONS.to_latest(&mut conn).await;
        assert!(
            matches!(result, Err((1, MigrationError::Turso(_)))),
            "Expected Turso error, got {:?}",
            result
        );

        assert_eq!(get_user_version(&conn).await.unwrap(), 1);
    }

    /// `up_file!` embeds a SQL file whose tables exist after migrating.
    #[tokio::test]
    async fn test_up_file_macro() {
        const FILE_MIGRATIONS: Migrations =
            Migrations::new(&[up_file!("../tests/migration-files/001_test.sql")]);

        let mut conn = get_in_memory_conn().await;

        let result = FILE_MIGRATIONS.to_latest(&mut conn).await;
        assert!(result.is_ok(), "File migration failed: {:?}", result.err());

        assert_eq!(get_user_version(&conn).await.unwrap(), 1);

        for sql in [
            "INSERT INTO file_test (id) VALUES (1)",
            "INSERT INTO file_test_2 (id) VALUES (1)",
        ] {
            conn.execute(sql, ()).await.unwrap();
        }
    }

    /// Migrations can borrow names and SQL from runtime-built strings.
    #[tokio::test]
    async fn test_dynamic_migration_name() {
        let dynamic_name = "003_dynamic".to_string();
        let dynamic_sql = "CREATE TABLE dynamic(id int)".to_string();

        let steps = [Migration::up(&dynamic_name, &dynamic_sql)];
        let runner = Migrations::new(&steps);

        let mut conn = get_in_memory_conn().await;

        assert_eq!(runner.to_latest(&mut conn).await.unwrap(), 1);
        assert_eq!(get_user_version(&conn).await.unwrap(), 1);
    }
}