Skip to main content

turso_migrate/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::future::Future;
4use std::pin::Pin;
5use thiserror::Error;
6use turso::{Connection, Error};
7
8/// Type alias for the future returned by migration functions.
9pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
10
11#[derive(Error, Debug)]
12pub enum MigrationError {
13    #[error("Turso error: {0}")]
14    Turso(#[from] Error),
15    #[error("Database has user_version {0}, but it must be between 0 and {1}")]
16    InvalidUserVersion(i32, usize),
17}
18
19/// A single migration step.
20/// Can be a raw SQL string, an async function, or a file.
21#[derive(Clone, Copy)]
22pub enum Migration<'a> {
23    Sql {
24        name: &'a str,
25        sql: &'a str,
26    },
27    Fn {
28        name: &'a str,
29        f: fn(&Connection) -> MigrationFuture,
30    },
31}
32
33impl<'a> Migration<'a> {
34    /// Creates a migration from a raw SQL string.
35    pub const fn up(name: &'a str, sql: &'a str) -> Self {
36        Self::Sql { name, sql }
37    }
38
39    /// Creates a migration from a rust function.
40    /// The function must take a `&Connection` and return a `MigrationFuture`.
41    /// Use the `up_fn!` macro to directly take an async function.
42    pub const fn up_fn(name: &'a str, f: fn(&Connection) -> MigrationFuture) -> Self {
43        Self::Fn { name, f }
44    }
45}
46
/// Helper macro to create a migration from a SQL file.
///
/// `$path` must be a string literal. It is passed to `include_str!`, so the SQL is
/// embedded at compile time, and the same path string becomes the migration's name.
///
/// # Example
///
/// ```rust
/// use turso_migrate::{up_file, Migration, Migrations};
///
/// const MIGRATIONS: Migrations = Migrations::new(&[
///     up_file!("../tests/migration-files/001_test.sql"),
/// ]);
/// ```
#[macro_export]
macro_rules! up_file {
    ($path:literal) => {
        // Fully qualified so that a user-defined `include_str!` macro in scope at the call
        // site cannot shadow the built-in one.
        $crate::Migration::up($path, ::core::include_str!($path))
    };
}
65
/// Helper macro to create a migration from an async function.
///
/// `$func` must be a path to a function taking `&turso::Connection` and returning a
/// `Send` future whose output is `turso::Result<()>`. The macro generates a wrapper that
/// boxes that future into a [`MigrationFuture`], which is what [`Migration::up_fn`] expects.
///
/// # Example
///
/// ```rust
/// use turso_migrate::{up_fn, Migration, Migrations};
///
/// async fn my_migration(conn: &turso::Connection) -> turso::Result<()> {
///     Ok(())
/// }
///
/// const MIGRATIONS: Migrations = Migrations::new(&[
///     up_fn!("001", my_migration),
/// ]);
/// ```
#[macro_export]
macro_rules! up_fn {
    ($name:expr, $func:path) => {{
        // Paths are fully qualified so that call-site items named `turso` or `Box` cannot
        // change what the expansion refers to.
        fn wrapper(conn: &::turso::Connection) -> $crate::MigrationFuture<'_> {
            ::std::boxed::Box::pin($func(conn))
        }
        $crate::Migration::up_fn($name, wrapper)
    }};
}
90
91/// Manages the application of migrations.
92pub struct Migrations<'a> {
93    migrations: &'a [Migration<'a>],
94}
95
96impl<'a> Migrations<'a> {
97    pub const fn new(migrations: &'a [Migration<'a>]) -> Self {
98        Self { migrations }
99    }
100
101    /// Applies all pending migrations to bring the database to the latest version.
102    /// Each migration is applied in its own transaction.
103    /// Returns the number of applied migrations.
104    /// Uses the turso/sqlite table user_version to track the current version.
105    pub async fn to_latest(&self, conn: &mut Connection) -> Result<usize, (usize, MigrationError)> {
106        let current_version = match get_user_version(conn).await {
107            Ok(v) => v,
108            Err(e) => return Err((0, e.into())),
109        };
110        let target_version = self.migrations.len() as i32;
111        if current_version == target_version {
112            return Ok(0);
113        }
114        if current_version < 0 || current_version > target_version {
115            return Err((
116                0,
117                MigrationError::InvalidUserVersion(current_version, self.migrations.len()),
118            ));
119        }
120
121        let mut applied_count = 0;
122        for (i, migration) in self
123            .migrations
124            .iter()
125            .enumerate()
126            .skip(current_version as usize)
127        {
128            let version = (i + 1) as i32;
129
130            // Start a transaction for this migration
131            let tx = match conn.transaction().await {
132                Ok(tx) => tx,
133                Err(e) => return Err((applied_count, e.into())),
134            };
135
136            // Apply the migration logic
137            let result = match migration {
138                Migration::Sql { name: _name, sql } => tx.execute_batch(sql).await.map(|_| ()),
139                Migration::Fn { name: _name, f } => f(&tx).await,
140            };
141
142            // Capture the name for logging, ensuring we use the one bound in the match or constructing it
143            #[cfg(feature = "tracing")]
144            let migration_name_log = match migration {
145                Migration::Sql { name, .. } => *name,
146                Migration::Fn { name, .. } => *name,
147            };
148
149            if let Err(e) = result {
150                #[cfg(feature = "tracing")]
151                tracing::error!(error = ?e, migration_name = %migration_name_log, "Migration failed");
152                return Err((applied_count, e.into()));
153            }
154
155            // Update the user_version within the same transaction
156            if let Err(e) = set_user_version(&tx, version).await {
157                #[cfg(feature = "tracing")]
158                tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to update user_version");
159                return Err((applied_count, e.into()));
160            }
161
162            // Commit the transaction
163            if let Err(e) = tx.commit().await {
164                #[cfg(feature = "tracing")]
165                tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to commit transaction");
166                return Err((applied_count, e.into()));
167            }
168
169            #[cfg(feature = "tracing")]
170            tracing::debug!(migration_name = %migration_name_log, "Migration applied successfully");
171
172            applied_count += 1;
173        }
174
175        Ok(applied_count)
176    }
177
178    /// Helper function to validate migrations by applying them to an in-memory database.
179    /// Returns the connection to the in-memory database if successful.
180    pub async fn run_all_in_memory(&self) -> Result<Connection, MigrationError> {
181        let db = turso::Builder::new_local(":memory:").build().await?;
182        let mut conn = db.connect()?;
183        if let Err((_, e)) = self.to_latest(&mut conn).await {
184            return Err(e);
185        }
186        Ok(conn)
187    }
188}
189
190async fn get_user_version(conn: &Connection) -> Result<i32, Error> {
191    let version = std::sync::atomic::AtomicI32::new(0);
192    conn.pragma_query("user_version", |row| {
193        let v = row.get::<i32>(0).unwrap();
194        version.store(v, std::sync::atomic::Ordering::SeqCst);
195        Ok(())
196    })
197    .await?;
198    Ok(version.load(std::sync::atomic::Ordering::SeqCst))
199}
200
201async fn set_user_version(conn: &Connection, version: i32) -> Result<(), Error> {
202    conn.pragma_update("user_version", version).await?;
203    Ok(())
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use turso::Builder;

    const MIGRATIONS: Migrations = Migrations::new(&[
        Migration::up("001", "CREATE TABLE friend(name TEXT NOT NULL);"),
        up_fn!("002", my_complex_migration),
    ]);

    /// Two migrations where the second always fails because it queries a missing table.
    const BROKEN_MIGRATIONS: Migrations = Migrations::new(&[
        Migration::up("001", "CREATE TABLE ok (id int)"),
        Migration::up("002", "SELECT * FROM non_existent_table"), // This should fail
    ]);

    // Example of an external migration function
    async fn my_complex_migration(conn: &Connection) -> turso::Result<()> {
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", ())
            .await?;
        conn.execute("INSERT INTO users (name) VALUES ('Alice')", ())
            .await?;
        Ok(())
    }

    /// Opens a fresh in-memory database.
    async fn get_in_memory_conn() -> Connection {
        let db = Builder::new_local(":memory:").build().await.unwrap();
        db.connect().unwrap()
    }

    /// Asserts that the database's `user_version` equals `expected`.
    async fn assert_user_version(conn: &Connection, expected: i32) {
        let version = get_user_version(conn).await.unwrap();
        assert_eq!(version, expected);
    }

    #[tokio::test]
    async fn test_migrations_success() {
        let mut conn = get_in_memory_conn().await;

        let applied_count = MIGRATIONS.to_latest(&mut conn).await.unwrap();
        assert_eq!(applied_count, 2);
        assert_user_version(&conn, 2).await;

        // Both migrations created their tables.
        conn.execute("INSERT INTO friend (name) VALUES ('test')", ())
            .await
            .unwrap();
        conn.execute("INSERT INTO users (name) VALUES ('test')", ())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_idempotency() {
        let mut conn = get_in_memory_conn().await;

        // First run applies everything, second run finds nothing to do.
        assert_eq!(MIGRATIONS.to_latest(&mut conn).await.unwrap(), 2);
        assert_eq!(MIGRATIONS.to_latest(&mut conn).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_empty_migrations() {
        const NO_MIGRATIONS: Migrations = Migrations::new(&[]);
        let mut conn = get_in_memory_conn().await;

        // A fresh database (user_version 0) is already up to date with zero migrations.
        assert_eq!(NO_MIGRATIONS.to_latest(&mut conn).await.unwrap(), 0);
        assert_user_version(&conn, 0).await;
    }

    #[tokio::test]
    async fn test_partial_migration() {
        let mut conn = get_in_memory_conn().await;

        // Pretend migration 1 has already been applied.
        conn.pragma_update("user_version", 1).await.unwrap();

        let count = MIGRATIONS.to_latest(&mut conn).await.unwrap();
        assert_eq!(count, 1);
        assert_user_version(&conn, 2).await;
    }

    #[tokio::test]
    async fn test_validate_helper() {
        // Keep the error in the panic message instead of a bare `is_ok()` assertion.
        let conn = MIGRATIONS
            .run_all_in_memory()
            .await
            .unwrap_or_else(|e| panic!("run_all_in_memory failed: {:?}", e));

        // The returned connection points at the fully migrated database.
        assert_user_version(&conn, 2).await;
    }

    #[tokio::test]
    async fn test_validate_helper_reports_failure() {
        let err = BROKEN_MIGRATIONS.run_all_in_memory().await.err();
        assert!(
            matches!(err, Some(MigrationError::Turso(_))),
            "Expected Turso error, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn test_negative_user_version() {
        let mut conn = get_in_memory_conn().await;

        conn.pragma_update("user_version", -5).await.unwrap();

        let result = MIGRATIONS.to_latest(&mut conn).await;

        match result {
            Err((0, MigrationError::InvalidUserVersion(v, 2))) => assert_eq!(v, -5),
            _ => panic!("Expected InvalidUserVersion error, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_user_version_too_high() {
        let mut conn = get_in_memory_conn().await;

        // Version higher than the migration count (2).
        conn.pragma_update("user_version", 10).await.unwrap();

        let result = MIGRATIONS.to_latest(&mut conn).await;

        match result {
            Err((0, MigrationError::InvalidUserVersion(v, max))) => {
                assert_eq!(v, 10);
                assert_eq!(max, 2);
            }
            _ => panic!("Expected InvalidUserVersion error, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_failing_migration() {
        let mut conn = get_in_memory_conn().await;

        let result = BROKEN_MIGRATIONS.to_latest(&mut conn).await;
        match result {
            Err((1, MigrationError::Turso(_))) => {} // Expected one successful migration
            _ => panic!("Expected Turso error, got {:?}", result),
        }

        // The version records only the migration that succeeded.
        assert_user_version(&conn, 1).await;
    }

    #[tokio::test]
    async fn test_up_file_macro() {
        let mut conn = get_in_memory_conn().await;

        const FILE_MIGRATIONS: Migrations =
            Migrations::new(&[up_file!("../tests/migration-files/001_test.sql")]);

        let result = FILE_MIGRATIONS.to_latest(&mut conn).await;
        assert!(result.is_ok(), "File migration failed: {:?}", result.err());
        assert_user_version(&conn, 1).await;

        // Both tables from the file were created.
        conn.execute("INSERT INTO file_test (id) VALUES (1)", ())
            .await
            .unwrap();

        conn.execute("INSERT INTO file_test_2 (id) VALUES (1)", ())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_dynamic_migration_name() {
        let mut conn = get_in_memory_conn().await;

        // Migrations may borrow runtime-built strings, not only literals.
        let dynamic_name = String::from("003_dynamic");
        let dynamic_sql = String::from("CREATE TABLE dynamic(id int)");

        let migrations = [Migration::up(&dynamic_name, &dynamic_sql)];
        let migrations = Migrations::new(&migrations);

        let applied = migrations.to_latest(&mut conn).await.unwrap();
        assert_eq!(applied, 1);
        assert_user_version(&conn, 1).await;
    }
}