tokio_postgres_migration/
lib.rs

1use tokio_postgres::GenericClient;
2
/// Applies and reverts SQL migration scripts, recording the names of applied
/// scripts in a bookkeeping table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Name of the bookkeeping table. It is interpolated verbatim (unquoted,
    /// unescaped) into SQL, so it must come from trusted configuration.
    tablename: String,
}
6
7impl Migration {
8    pub fn new(tablename: String) -> Self {
9        Self { tablename }
10    }
11
12    async fn execute_script<C: GenericClient>(
13        &self,
14        client: &C,
15        content: &str,
16    ) -> Result<(), tokio_postgres::Error> {
17        let stmt = client.prepare(content).await?;
18        client.execute(&stmt, &[]).await?;
19        Ok(())
20    }
21
22    async fn insert_migration<C: GenericClient>(
23        &self,
24        client: &C,
25        name: &str,
26    ) -> Result<(), tokio_postgres::Error> {
27        let query = format!("INSERT INTO {} (name) VALUES ($1)", self.tablename);
28        let stmt = client.prepare(&query).await?;
29        client.execute(&stmt, &[&name]).await?;
30        Ok(())
31    }
32
33    async fn delete_migration<C: GenericClient>(
34        &self,
35        client: &C,
36        name: &str,
37    ) -> Result<(), tokio_postgres::Error> {
38        let query = format!("DELETE FROM {} WHERE name = $1", self.tablename);
39        let stmt = client.prepare(&query).await?;
40        client.execute(&stmt, &[&name]).await?;
41        Ok(())
42    }
43
44    async fn create_table<C: GenericClient>(
45        &self,
46        client: &C,
47    ) -> Result<(), tokio_postgres::Error> {
48        log::debug!("creating migration table {}", self.tablename);
49        let query = format!(
50            r#"CREATE TABLE IF NOT EXISTS {} ( name TEXT NOT NULL PRIMARY KEY, executed_at TIMESTAMP NOT NULL DEFAULT NOW() )"#,
51            self.tablename
52        );
53        self.execute_script(client, &query).await?;
54        Ok(())
55    }
56
57    async fn exists<C: GenericClient>(
58        &self,
59        client: &C,
60        name: &str,
61    ) -> Result<bool, tokio_postgres::Error> {
62        log::trace!("check if migration {} exists", name);
63        let query = format!("SELECT COUNT(*) FROM {} WHERE name = $1", self.tablename);
64        let stmt = client.prepare(&query).await?;
65        let row = client.query_one(&stmt, &[&name]).await?;
66        let count: i64 = row.get(0);
67
68        Ok(count > 0)
69    }
70
71    /// Migrate all scripts up
72    pub async fn up<C: GenericClient>(
73        &self,
74        client: &mut C,
75        scripts: &[(&str, &str)],
76    ) -> Result<(), tokio_postgres::Error> {
77        log::info!("migrating up to {}", self.tablename);
78        self.create_table(client).await?;
79        for (name, script) in scripts {
80            if !self.exists(client, name).await? {
81                log::debug!("deleting migration {}", name);
82                self.execute_script(client, script).await?;
83                self.insert_migration(client, name).await?;
84            }
85        }
86        Ok(())
87    }
88
89    /// Migrate all scripts down
90    pub async fn down<C: GenericClient>(
91        &self,
92        client: &C,
93        scripts: &[(&str, &str)],
94    ) -> Result<(), tokio_postgres::Error> {
95        log::info!("migrating down to {}", self.tablename);
96        self.create_table(client).await?;
97        for (name, script) in scripts {
98            if self.exists(client, name).await? {
99                log::debug!("deleting migration {}", name);
100                self.execute_script(client, script).await?;
101                self.delete_migration(client, name).await?;
102            }
103        }
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::Migration;
    use std::str::FromStr;

    /// Forward scripts, applied in declaration order.
    const SCRIPTS_UP: [(&str, &str); 2] = [
        (
            "0001-create-table-users",
            include_str!("../assets/0001-create-table-users-up.sql"),
        ),
        (
            "0002-create-table-pets",
            include_str!("../assets/0002-create-table-pets-up.sql"),
        ),
    ];

    /// Rollback scripts, listed in the reverse of `SCRIPTS_UP` so the later
    /// migration is undone first.
    const SCRIPTS_DOWN: [(&str, &str); 2] = [
        (
            "0002-create-table-pets",
            include_str!("../assets/0002-create-table-pets-down.sql"),
        ),
        (
            "0001-create-table-users",
            include_str!("../assets/0001-create-table-users-down.sql"),
        ),
    ];

    /// Connection URL from `POSTGRES_URL`, falling back to a local default.
    fn get_url() -> String {
        std::env::var("POSTGRES_URL").unwrap_or_else(|_| {
            "postgres://postgres@localhost:5432/postgres?connect_timeout=5".to_string()
        })
    }

    fn get_config() -> tokio_postgres::Config {
        tokio_postgres::Config::from_str(&get_url()).unwrap()
    }

    /// Connects without TLS and drives the connection on a background task.
    async fn get_client() -> tokio_postgres::Client {
        let cfg = get_config();
        let (client, con) = cfg.connect(tokio_postgres::NoTls).await.unwrap();
        tokio::spawn(con);
        client
    }

    /// Number of rows currently recorded in the bookkeeping table `table`.
    async fn count_applied(client: &tokio_postgres::Client, table: &str) -> i64 {
        let query = format!("SELECT COUNT(*) FROM {}", table);
        let row = client.query_one(query.as_str(), &[]).await.unwrap();
        row.get(0)
    }

    #[tokio::test]
    async fn migrating() {
        let mut client = get_client().await;
        let migration = Migration::new("table_name".to_string());

        migration.up(&mut client, &SCRIPTS_UP).await.unwrap();
        assert_eq!(count_applied(&client, "table_name").await, 2);

        // A second run must skip the already-recorded migrations instead of
        // re-running their scripts.
        migration.up(&mut client, &SCRIPTS_UP).await.unwrap();
        assert_eq!(count_applied(&client, "table_name").await, 2);

        migration.down(&client, &SCRIPTS_DOWN).await.unwrap();
        assert_eq!(count_applied(&client, "table_name").await, 0);
    }
}