// silent_db/core/migrate.rs

use std::fs::{self, File, OpenOptions};
use std::io::{BufWriter, Write};
use std::ops::Deref;
use std::path::{Path, PathBuf};

use anyhow::{anyhow, Context, Result};
use chrono::Utc;
use console::style;
use sqlparser::ast::Statement;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlx::any::AnyConnectOptions;
use sqlx::migrate::{MigrationType, Migrator};
use sqlx::{AnyConnection, ConnectOptions, FromRow};

use crate::core::dsl::SqlStatement;
use crate::core::tables::TableUtil;
use crate::Table;

19pub struct Migrate {
20    pub(crate) migrations_path: String,
21    pub(crate) options: AnyConnectOptions,
22    pub(crate) conn: Option<AnyConnection>,
23}
24
25#[derive(Debug, FromRow)]
26struct TableName(String);
27
28#[derive(Debug, FromRow)]
29struct TableCreate(String, String);
30
31impl Migrate {
32    pub fn new(migrations_path: String, options: AnyConnectOptions) -> Self {
33        Migrate {
34            migrations_path,
35            options,
36            conn: None,
37        }
38    }
39    #[inline]
40    async fn connect(&mut self) -> Result<()> {
41        if self.conn.is_none() {
42            sqlx::any::install_default_drivers();
43            let conn = self.options.connect().await?;
44            self.conn = Some(conn);
45        }
46        Ok(())
47    }
48    #[inline]
49    async fn get_conn(mut self) -> Result<AnyConnection> {
50        if self.conn.is_none() {
51            self.connect().await?;
52        }
53        self.conn.ok_or(anyhow!("no connection"))
54    }
55    #[inline]
56    async fn get_migrator(&self) -> Result<Migrator> {
57        Migrator::new(Path::new(&self.migrations_path))
58            .await
59            .map_err(|e| anyhow!("{}", e))
60    }
61    pub async fn make_migration(&mut self, tables: Vec<impl Table>) -> Result<()> {
62        println!("run migrate");
63        // todo!()
64        let _ = tables;
65        for table in tables {
66            table.get_create_sql();
67        }
68        Ok(())
69    }
70    pub async fn migrate(self) -> Result<()> {
71        println!("run migrate");
72        let migrator = self.get_migrator().await?;
73        let mut conn = self.get_conn().await?;
74        migrator
75            .run(&mut conn)
76            .await
77            .map_err(|e| anyhow!("migrate error: {}", e))
78    }
79
80    pub async fn rollback(self, target: i64) -> Result<()> {
81        println!("run rollback");
82        let migrator = self.get_migrator().await?;
83        let mut conn = self.get_conn().await?;
84        migrator.undo(&mut conn, target).await?;
85        Ok(())
86    }
87    async fn get_exist_tables(self, utils: &dyn TableUtil) -> Result<Vec<SqlStatement>> {
88        let mut conn = self.get_conn().await?;
89        if conn.backend_name() != utils.get_name() {
90            return Err(anyhow!(
91                "database({}) is not support, database({}) is supported",
92                conn.backend_name(),
93                utils.get_name()
94            ));
95        }
96        let sql = utils.get_all_tables();
97        let tables: Vec<TableName> = sqlx::query_as(&sql)
98            .fetch_all(&mut conn)
99            .await
100            .map_err(|e| anyhow!("{}", e))?;
101        let mut generate_tables: Vec<SqlStatement> = vec![];
102        for table in tables {
103            let create_list: Vec<TableCreate> = sqlx::query_as(&utils.get_table(&table.0))
104                .fetch_all(&mut conn)
105                .await?;
106            let dialect = GenericDialect {};
107            for create in create_list {
108                let sql = create.1;
109                let ast = Parser::parse_sql(&dialect, &sql)?
110                    .pop()
111                    .ok_or(anyhow!("failed to parse sql"))?;
112                if let Statement::CreateTable { name, .. } = ast.clone() {
113                    if name.to_string() == "`_sqlx_migrations`" {
114                        continue;
115                    }
116                }
117                generate_tables.push((ast, sql.to_string()).into());
118            }
119        }
120        Ok(generate_tables)
121    }
122    async fn generate_files(
123        self,
124        utils: Box<dyn TableUtil>,
125        up_path_buf: PathBuf,
126        down_path_buf: PathBuf,
127    ) -> Result<()> {
128        let mut generate_tables = self.get_exist_tables(utils.deref()).await?;
129        let mut up_file = OpenOptions::new().append(true).open(&up_path_buf)?;
130        let mut down_file = OpenOptions::new().append(true).open(&down_path_buf)?;
131        generate_tables.sort();
132        for sql in &generate_tables {
133            up_file.write_all(sql.1.as_bytes())?;
134            up_file.write_all(b";\n")?;
135        }
136        generate_tables.reverse();
137        for sql in &generate_tables {
138            let table = utils.transform(sql)?;
139            down_file.write_all(table.get_drop_sql().as_bytes())?;
140            down_file.write_all(b"\n")?;
141        }
142        let models_path = Path::new("./src/models");
143        utils.generate_models(generate_tables, models_path)?;
144        println!(
145            "Generate models at {:?} success",
146            style(models_path).green()
147        );
148        Ok(())
149    }
150    pub async fn generate(self, utils: Box<dyn TableUtil>) -> Result<()> {
151        println!("run generate");
152        let migrations_path = self.migrations_path.clone();
153        let path = Path::new(&migrations_path);
154        if !path.exists() {
155            fs::create_dir_all(path)?;
156        }
157        if !path.is_dir() {
158            return Err(anyhow!("migrations path is not a directory"));
159        }
160        // TODO: check if the migrations path is empty
161        // if path.read_dir()?.next().is_some() {
162        //     return Err(anyhow!("migrations path is not empty"));
163        // }
164        let prefix = Utc::now().format("%Y%m%d%H%M%S").to_string();
165        let up_path_buf = create_file(
166            &migrations_path,
167            &prefix,
168            "init",
169            MigrationType::ReversibleUp,
170        )?;
171        let down_path_buf = create_file(
172            &migrations_path,
173            &prefix,
174            "init",
175            MigrationType::ReversibleDown,
176        )?;
177        match self
178            .generate_files(utils, up_path_buf.clone(), down_path_buf.clone())
179            .await
180        {
181            Ok(_) => {
182                println!(
183                    "Migration files {} generated at ./{}",
184                    style(prefix).yellow(),
185                    style(migrations_path).green()
186                );
187                Ok(())
188            }
189            Err(e) => {
190                fs::remove_file(up_path_buf)?;
191                fs::remove_file(down_path_buf)?;
192                println!(
193                    "Migration files {} generate failed by {}",
194                    style(prefix).yellow(),
195                    style(e.to_string()).red()
196                );
197                Err(e)
198            }
199        }
200    }
201}
202
203fn create_file(
204    migration_source: &str,
205    file_prefix: &str,
206    description: &str,
207    migration_type: MigrationType,
208) -> Result<PathBuf> {
209    let mut file_name = file_prefix.to_string();
210    file_name.push('_');
211    file_name.push_str(&description.replace(' ', "_"));
212    file_name.push_str(migration_type.suffix());
213
214    let mut path = PathBuf::new();
215    path.push(migration_source);
216    path.push(&file_name);
217
218    println!("Creating {}", style(path.display()).cyan());
219
220    let mut file = File::create(&path).context("Failed to create migration file")?;
221
222    Write::write_all(&mut file, migration_type.file_content().as_bytes())?;
223
224    Ok(path)
225}