// surrealdb_simple_migration — lib.rs

1extern crate chrono;
2
3use chrono::prelude::*;
4use std::fmt;
5
6use regex::Regex;
7use serde::Deserialize;
8
9use surrealdb::{engine::any::Any, Surreal};
10use tokio::{
11    fs::{read_dir, File},
12    io::AsyncReadExt,
13};
14
/// A record of one applied migration, as stored in the `migrations` table.
#[derive(Deserialize, PartialEq, Debug, Clone)]
pub struct Migration {
    // Name of the migration file, e.g. "001_create_user_table.surql".
    filename: String,
    // Timestamp set by the database (`VALUE time::now()`) when the migration
    // row was created.
    created_at: DateTime<Utc>,
}
20
/// Errors that can occur while running migrations.
#[derive(Debug)]
pub enum Error {
    /// Filesystem error while reading the migration directory or a file.
    IO(std::io::Error),
    /// Error returned by the SurrealDB client.
    Surreal(surrealdb::Error),
    /// A migration file was modified/reordered after migrations were applied.
    ForbiddenUpdate(String),
    /// A previously migrated file is missing from the migration directory.
    ForbiddenRemoval(String),
}
28
29impl From<std::io::Error> for Error {
30    fn from(err: std::io::Error) -> Self {
31        Error::IO(err)
32    }
33}
34
35impl From<surrealdb::Error> for Error {
36    fn from(err: surrealdb::Error) -> Self {
37        Error::Surreal(err)
38    }
39}
40
41impl PartialEq<String> for Migration {
42    fn eq(&self, other: &String) -> bool {
43        self.filename.to_string() == *other
44    }
45}
46
47impl fmt::Display for Error {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match *self {
50            Error::IO(ref err) => write!(f, "IO error: {}", err),
51            Error::Surreal(ref err) => write!(f, "Surreal error: {}", err),
52            Error::ForbiddenUpdate(ref err) => write!(f, "Forbidden update: {}", err),
53            Error::ForbiddenRemoval(ref err) => write!(f, "Forbidden removal: {}", err),
54        }
55    }
56}
57
58impl std::error::Error for Error {
59    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
60        match *self {
61            Error::IO(ref err) => Some(err),
62            Error::Surreal(ref err) => Some(err),
63            Error::ForbiddenUpdate(_) => None,
64            Error::ForbiddenRemoval(_) => None,
65        }
66    }
67}
68
69pub async fn migrate(db: &Surreal<Any>, migration_dir_path: &str) -> Result<(), Error> {
70    setup_migration_table(db).await?;
71    run_migration_files(db, migration_dir_path).await?;
72
73    Ok(())
74}
75
76async fn setup_migration_table(db: &Surreal<Any>) -> Result<(), surrealdb::Error> {
77    let sql = r#"
78        DEFINE TABLE IF NOT EXISTS migrations SCHEMAFULL;
79        DEFINE FIELD IF NOT EXISTS filename ON TABLE migrations TYPE string;
80        DEFINE FIELD IF NOT EXISTS created_at ON TABLE migrations TYPE datetime VALUE time::now();
81    "#;
82
83    let _ = db.query(sql).await?.check()?;
84
85    Ok(())
86}
87
88async fn run_migration_files(db: &Surreal<Any>, migration_dir_path: &str) -> Result<(), Error> {
89    // Get the files already processed.
90    let migrations = db
91        .query("SELECT * FROM migrations ORDER BY created_at ASC;")
92        .await?
93        .check()?
94        .take::<Vec<Migration>>(0)?;
95    let mut remaining_migrations: Vec<Migration> = migrations.clone();
96
97    println!("Migrated files: {:#?}", migrations);
98
99    // Get the surql migration files to execute.
100    let mut dir = read_dir(migration_dir_path).await?;
101    let mut entries: Vec<String> = vec![];
102
103    // Filter the files that fit the migration pattern.
104    while let Some(dir_entry) = dir.next_entry().await? {
105        let filename = dir_entry
106            .path()
107            .to_str()
108            .unwrap()
109            .to_string()
110            .replace((migration_dir_path.to_owned() + "/").as_str(), "");
111        let pattern = r"^[0-9]+[a-zA-Z_0-9]{0,}\.surql$";
112        let regex = Regex::new(&pattern).expect("Failed to build the regexp");
113        if regex.is_match(&filename) {
114            entries.push(filename);
115        }
116    }
117
118    // Sort the entries (by their number prefix).
119    entries.sort(); // TODO: Check how the strings are sorted.
120
121    // Process migration files.
122    println!("Migration files: {:#?}", entries);
123
124    let last_migration = migrations.last();
125
126    // Checker - check for forbidden updates and removals.
127    for entry in entries {
128        // Get the file descriptor.
129        let mut file = File::open(migration_dir_path.to_owned() + "/" + &entry).await?;
130
131        // Check if the file has already been migrated.
132        let migrated = migrations
133            .iter()
134            .any(|migration: &Migration| migration == &entry);
135
136        // If migrated, check that the last update date is anterior to the created_at.
137        if migrated {
138            let updated_at: DateTime<Utc> = File::metadata(&file).await?.modified()?.into();
139
140            // Ensure the file has not been updated after the last migration.
141            if updated_at > last_migration.unwrap().created_at {
142                println!("[X] Forbidden: The migration file '{}' has been updated after the last migration.", entry);
143                return Err(Error::ForbiddenUpdate(format!(
144                    "Forbidden: The migration file '{}' has been updated after the last migration.",
145                    entry
146                )));
147            }
148
149            println!("[V] File already migrated: {}", entry);
150        } else {
151            // TODO: Check that the new migration file appears after the last migration file.
152            let mut migration_content: String = String::new();
153            file.read_to_string(&mut migration_content).await?;
154
155            // When the last migration file is created after the current file, it should fail.
156            if last_migration != None
157                && last_migration.unwrap().created_at
158                    > DateTime::<Utc>::from(File::metadata(&file).await?.modified()?)
159            {
160                println!(
161                    "[X] The migration file '{}' appears before the last migration file '{}'.",
162                    &entry,
163                    last_migration.unwrap().filename
164                );
165
166                return Err(Error::ForbiddenUpdate(format!(
167                    "The migration file '{}' appears before the last migration file '{}'.",
168                    &entry,
169                    last_migration.unwrap().filename
170                )));
171            }
172
173            // Migrate the file.
174            let _ = db.query(migration_content).await?;
175            let _ = db
176                .query("CREATE migrations SET filename=$filename;")
177                .bind(("filename", entry.clone()))
178                .await?
179                .check()?;
180
181            println!("[V] File successfuly migrated: {}", &entry);
182        }
183
184        // Update the migrations list.
185        let position = remaining_migrations
186            .iter()
187            .position(|migration| *migration.filename == entry);
188        if let Some(pos) = position {
189            remaining_migrations.remove(pos);
190        }
191    }
192
193    if remaining_migrations.len() > 0 {
194        println!(
195            "[X] Some migration files are missing - migrations failed: {:?}",
196            remaining_migrations
197        );
198        return Err(Error::ForbiddenRemoval(format!(
199            "Some migration files are missing - migrations failed: {:?}",
200            remaining_migrations
201        )));
202    }
203
204    Ok(())
205}
206
#[cfg(test)]
mod tests {
    use std::fs::create_dir_all;

    use surrealdb::{engine::any, opt::auth::Root};
    use tokio::{fs::File, io::AsyncWriteExt};

    // Removes the test migration directory and empties the `migrations`
    // table so each run starts from a clean slate.
    // NOTE(review): these tests require a SurrealDB server listening on
    // 0.0.0.0:8000 with root/root credentials — they are integration tests.
    async fn clean_up() {
        let db = any::connect("ws://0.0.0.0:8000").await.unwrap();

        db.signin(Root {
            username: "root",
            password: "root",
        })
        .await
        .expect("Failed to sign in.");

        db.use_ns("env")
            .use_db("ssm_test")
            .await
            .expect("Failed to use namespace 'env' with database 'dev'.");

        // Directory may not exist on the first run; ignore the error.
        let _ = tokio::fs::remove_dir_all("test/migrations").await;
        let _ = db
            .query("DELETE migrations;")
            .await
            .expect("Failed to delete migrations table.");
    }

    // End-to-end scenario test covering, in order:
    // 1. fresh migration of matching files, 2. idempotent re-run,
    // 3. incremental migration of a new file, 4. rejection of an updated
    // already-migrated file, 5. rejection when a migrated file is removed.
    #[tokio::test]
    async fn it_migrates_migration_files() {
        // Cleanup
        clean_up().await;

        // Setup database.
        let db = any::connect("ws://0.0.0.0:8000")
            .await
            .expect("Failed to connect to the database.");

        db.signin(Root {
            username: "root",
            password: "root",
        })
        .await
        .expect("Failed to sign in.");

        db.use_ns("env")
            .use_db("ssm_test")
            .await
            .expect("Failed to use namespace 'env' with database 'dev'.");

        // 1. When migration files fit the required pattern, it should process them.
        // Arrange - Create fake migration files.
        let migration_dir_path = "test/migrations";

        let _ = create_dir_all(migration_dir_path)
            .expect("Failed to create directory for migration files.");
        // NOTE(review): "ON TABLE user" below looks like a typo for "users" —
        // harmless to the test, but confirm before reusing this fixture.
        let mut file1 =
            File::create(migration_dir_path.to_owned() + "/001_create_user_table.surql")
                .await
                .unwrap();
        file1
            .write_all(
                b"
            DEFINE TABLE users SCHEMAFULL;
            DEFINE FIELD name ON TABLE user TYPE string;
            DEFINE FIELD email ON TABLE users TYPE string;
            DEFINE FIELD created_at ON TABLE users TYPE datetime VALUE time::now();
        ",
            )
            .await
            .unwrap();

        let mut file2 =
            File::create(migration_dir_path.to_owned() + "/002_create_post_table.surql")
                .await
                .unwrap();
        file2
            .write_all(
                b"
            DEFINE TABLE posts SCHEMAFULL;
            DEFINE FIELD title ON TABLE posts TYPE string;
            DEFINE FIELD content ON TABLE posts TYPE string;
            DEFINE FIELD created_at ON TABLE posts TYPE datetime VALUE time::now();
        ",
            )
            .await
            .unwrap();

        let mut file3 =
            File::create(migration_dir_path.to_owned() + "/003_create_comment_table.surql")
                .await
                .unwrap();
        file3
            .write_all(
                b"
            DEFINE TABLE comments SCHEMAFULL;
            DEFINE FIELD content ON TABLE comments TYPE string;
            DEFINE FIELD created_at ON TABLE comments TYPE datetime VALUE time::now();
        ",
            )
            .await
            .unwrap();

        let mut file4 = File::create(migration_dir_path.to_owned() + "/004_i18n_table.surql")
            .await
            .unwrap();
        file4
            .write_all(
                b"
            DEFINE TABLE i18n SCHEMAFULL;
            DEFINE FIELD locale ON TABLE i18n TYPE string;
            DEFINE FIELD text ON TABLE i18n TYPE string;
        ",
            )
            .await
            .unwrap();

        // Act - Run the migration.
        let result = super::migrate(&db, migration_dir_path).await;

        // Assert
        assert!(result.is_ok());

        // 2. When migration files are already processed, it should skip them.
        // Act - Run the migration again.
        let result = super::migrate(&db, migration_dir_path).await;

        // Assert
        assert!(result.is_ok());

        // 3. When new migration files are added, it should process them.
        // Arrange - Add a new migration file.
        let mut file5 =
            File::create(migration_dir_path.to_owned() + "/005_create_likes_table.surql")
                .await
                .unwrap();
        file5
            .write_all(
                b"
            DEFINE TABLE likes SCHEMAFULL;
            DEFINE FIELD user_id ON TABLE likes TYPE record;
            DEFINE FIELD post_id ON TABLE likes TYPE string;
            DEFINE FIELD created_at ON TABLE likes TYPE datetime VALUE time::now();
        ",
            )
            .await
            .unwrap();

        // Act
        let result = super::migrate(&db, migration_dir_path).await;

        // Assert
        assert!(result.is_ok());

        // 4. When migration files are updated, it should fail.
        // Arrange - Update the migration files. Appending bumps the file's
        // mtime past the last recorded migration, which the checker rejects.
        file1
            .write(
                b"
            DEFINE FIELD updated_at ON TABLE users TYPE datetime VALUE time::now();
        ",
            )
            .await
            .unwrap();

        // Act - Run the migration again.
        let res = super::migrate(&db, migration_dir_path).await;

        // Assert
        assert!(res.is_err());

        // 5. When a migrated file is removed, it should return an error.
        // Arrange - Reset the migrations, migrate the files again and remove one file.
        let _ = db.query("DELETE migrations;").await;
        super::migrate(&db, migration_dir_path)
            .await
            .expect("Failed to migrate the files.");
        tokio::fs::remove_file(migration_dir_path.to_owned() + "/001_create_user_table.surql")
            .await
            .unwrap();

        // Act
        let res = super::migrate(&db, migration_dir_path).await;

        // Assert
        assert!(res.is_err());

        // CLEANUP
        clean_up().await;

        // data cleaning - drop every table the fixtures created.
        db.query("REMOVE TABLE migrations;").await.unwrap();
        db.query("REMOVE TABLE users;").await.unwrap();
        db.query("REMOVE TABLE posts;").await.unwrap();
        db.query("REMOVE TABLE comments;").await.unwrap();
        db.query("REMOVE TABLE likes;").await.unwrap();
    }
}