// schema_installer/migrator.rs

use std::collections::{HashMap, HashSet};
use std::time::Instant;

use schema_sql_generator::common::generator_type::GeneratorType;

use crate::config::SchemaInstallerConfig;
use crate::connection::AnyPool;
use crate::error::SchemaInstallerError;
use crate::migration::{compare_versions, compute_checksum, MigrationSource};
9
10pub struct Migrator;
11
12impl Migrator {
13    pub async fn migrate(
14        config: &SchemaInstallerConfig,
15        source: Box<dyn MigrationSource>,
16    ) -> Result<(), SchemaInstallerError> {
17        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
18
19        pool.ensure_migration_table(&config.database_type)
20            .await?;
21
22        let applied = pool.get_applied_migrations().await?;
23        let applied_versions: HashSet<String> = applied
24            .iter()
25            .filter(|m| m.status == "success")
26            .map(|m| m.version.clone())
27            .collect();
28
29        let source_migrations = source.migrations()?;
30
31        for applied_migration in &applied {
32            if applied_migration.status == "success" {
33                if let Some(source_migration) = source_migrations
34                    .iter()
35                    .find(|m| m.version == applied_migration.version)
36                {
37                    let checksum = compute_checksum(&source_migration.sql);
38                    if checksum != applied_migration.checksum {
39                        return Err(SchemaInstallerError::ChecksumMismatch {
40                            version: applied_migration.version.clone(),
41                            expected: applied_migration.checksum.clone(),
42                            found: checksum,
43                        });
44                    }
45                }
46            }
47        }
48
49        let mut migrations = source_migrations;
50        migrations.retain(|m| !applied_versions.contains(&m.version));
51
52        if migrations.is_empty() {
53            println!("No pending migrations to apply");
54            return Ok(());
55        }
56
57        let tool_version = env!("CARGO_PKG_VERSION");
58
59        for migration in migrations {
60            let checksum = compute_checksum(&migration.sql);
61            let migration_id = pool
62                .insert_migration(
63                    &migration.version,
64                    &migration.script_path,
65                    &checksum,
66                    0,
67                    "pending",
68                    tool_version,
69                )
70                .await?;
71
72            let start = Instant::now();
73            match execute_migration(&pool, &config.database_type, &migration.sql).await {
74                Ok(_) => {
75                    let elapsed_ms = start.elapsed().as_millis() as i64;
76                    pool.update_migration_status(migration_id, "success", elapsed_ms)
77                        .await?;
78                    println!(
79                        "Applied migration: {} - {}",
80                        migration.version, migration.description
81                    );
82                }
83                Err(e) => {
84                    let elapsed_ms = start.elapsed().as_millis() as i64;
85                    pool.update_migration_status(migration_id, "failed", elapsed_ms)
86                        .await?;
87                    return Err(SchemaInstallerError::MigrationFailed {
88                        version: migration.version,
89                        error: e.to_string(),
90                    });
91                }
92            }
93        }
94
95        Ok(())
96    }
97
98    pub async fn info(
99        config: &SchemaInstallerConfig,
100        source: Box<dyn MigrationSource>,
101    ) -> Result<(), SchemaInstallerError> {
102        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
103
104        if let Err(_) = pool.ensure_migration_table(&config.database_type).await {
105        }
106
107        let applied = pool.get_applied_migrations().await.unwrap_or_default();
108        let source_migrations = source.migrations()?;
109
110        if applied.is_empty() && source_migrations.is_empty() {
111            println!("No migrations found");
112            return Ok(());
113        }
114
115        println!(
116            "{:<10} {:<30} {:<10} {:<30} {:<15}",
117            "Version", "Description", "Status", "Installed At", "Execution (ms)"
118        );
119        println!("{}", "-".repeat(95));
120
121        let mut all_versions: Vec<String> = applied.iter().map(|m| m.version.clone()).collect();
122        for migration in &source_migrations {
123            if !all_versions.contains(&migration.version) {
124                all_versions.push(migration.version.clone());
125            }
126        }
127
128        all_versions.sort_by(|a, b| compare_versions(a, b));
129
130        for version in all_versions {
131            if let Some(applied_mig) = applied.iter().find(|m| m.version == version) {
132                println!(
133                    "{:<10} {:<30} {:<10} {:<30} {:<15}",
134                    applied_mig.version,
135                    applied_mig.script_path.split('/').last().unwrap_or(""),
136                    applied_mig.status,
137                    applied_mig.installed_at,
138                    applied_mig.execution_time_ms
139                );
140            } else if let Some(source_mig) = source_migrations.iter().find(|m| m.version == version) {
141                println!(
142                    "{:<10} {:<30} {:<10} {:<30} {:<15}",
143                    version,
144                    source_mig.description,
145                    "Pending",
146                    "-",
147                    "-"
148                );
149            }
150        }
151
152        Ok(())
153    }
154
155    pub async fn validate(
156        config: &SchemaInstallerConfig,
157        source: Box<dyn MigrationSource>,
158    ) -> Result<(), SchemaInstallerError> {
159        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
160
161        pool.ensure_migration_table(&config.database_type)
162            .await?;
163
164        let applied = pool.get_applied_migrations().await?;
165        let source_migrations = source.migrations()?;
166
167        let mut mismatches = Vec::new();
168
169        for applied_migration in applied {
170            if applied_migration.status != "success" {
171                continue;
172            }
173
174            if let Some(source_migration) = source_migrations
175                .iter()
176                .find(|m| m.version == applied_migration.version)
177            {
178                let checksum = compute_checksum(&source_migration.sql);
179                if checksum != applied_migration.checksum {
180                    mismatches.push((
181                        applied_migration.version.clone(),
182                        applied_migration.checksum.clone(),
183                        checksum,
184                    ));
185                }
186            }
187        }
188
189        if !mismatches.is_empty() {
190            for (version, expected, found) in mismatches {
191                eprintln!(
192                    "Checksum mismatch for version {}: expected {}, found {}",
193                    version, expected, found
194                );
195            }
196            return Err(SchemaInstallerError::ChecksumMismatch {
197                version: "unknown".to_string(),
198                expected: "see above".to_string(),
199                found: "see above".to_string(),
200            });
201        }
202
203        println!("All migrations validated successfully");
204        Ok(())
205    }
206
207    pub async fn has_pending_migrations(
208        config: &SchemaInstallerConfig,
209        source: Box<dyn MigrationSource>,
210    ) -> Result<bool, SchemaInstallerError> {
211        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
212
213        if pool.ensure_migration_table(&config.database_type).await.is_err() {
214            return Ok(true);
215        }
216
217        let applied = pool.get_applied_migrations().await.unwrap_or_default();
218        let applied_versions: HashSet<String> = applied
219            .iter()
220            .filter(|m| m.status == "success")
221            .map(|m| m.version.clone())
222            .collect();
223
224        let source_migrations = source.migrations()?;
225        let pending = source_migrations
226            .iter()
227            .any(|m| !applied_versions.contains(&m.version));
228
229        Ok(pending)
230    }
231
232    pub async fn repair(
233        config: &SchemaInstallerConfig,
234        source: Box<dyn MigrationSource>,
235    ) -> Result<(), SchemaInstallerError> {
236        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
237
238        pool.delete_failed_migrations().await?;
239        println!("Deleted failed migrations");
240
241        let applied = pool.get_applied_migrations().await?;
242        let source_migrations = source.migrations()?;
243
244        for applied_migration in applied {
245            if applied_migration.status != "success" {
246                continue;
247            }
248
249            if let Some(source_migration) = source_migrations
250                .iter()
251                .find(|m| m.version == applied_migration.version)
252            {
253                let checksum = compute_checksum(&source_migration.sql);
254                if checksum != applied_migration.checksum {
255                    pool.update_migration_checksum(applied_migration.id, &checksum)
256                        .await?;
257                    println!(
258                        "Updated checksum for migration: {}",
259                        applied_migration.version
260                    );
261                }
262            }
263        }
264
265        Ok(())
266    }
267}
268
269async fn execute_migration(
270    pool: &AnyPool,
271    database_type: &GeneratorType,
272    sql: &str,
273) -> Result<(), SchemaInstallerError> {
274    let statements = split_sql_statements(sql, database_type);
275
276    for statement in statements {
277        let trimmed = statement.trim();
278        if !trimmed.is_empty() {
279            pool.execute_sql(trimmed).await?;
280        }
281    }
282
283    Ok(())
284}
285
286fn split_sql_statements(sql: &str, database_type: &GeneratorType) -> Vec<String> {
287    let separator = match database_type {
288        GeneratorType::SqlServer => "GO",
289        _ => ";",
290    };
291
292    sql.split(separator)
293        .map(|s| s.to_string())
294        .collect()
295}