Skip to main content

systemprompt_database/lifecycle/migrations/
mod.rs

//! Extension migration runner backed by the `extension_migrations`
//! bookkeeping table. [`MigrationService`] applies, reverts, inspects, and
//! squashes per-extension migration history; reverts live in [`down`],
//! status/plan queries in [`status`], squash in [`squash`], and
//! tracking-table repair in [`repair`].
5
6mod down;
7mod exec;
8mod repair;
9mod squash;
10mod status;
11
12pub use repair::RepairResult;
13pub use squash::SquashPlan;
14pub use status::{
15    AppliedMigration, ChecksumDrift, ExtensionMigrationStatus, MigrationResult, MigrationStatus,
16    PendingMigration,
17};
18
19use crate::services::{DatabaseProvider, SqlExecutor};
20use exec::{check_cross_extension_alters, execute_statements_transactional};
21use std::collections::HashSet;
22use systemprompt_extension::{Extension, LoaderError, Migration};
23use tracing::{debug, info, warn};
24
/// Policy knobs for [`MigrationService`].
///
/// The [`Default`] configuration is strict: every field is `false`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MigrationConfig {
    /// When `true`, an already-applied migration whose current checksum no
    /// longer matches the stored one is logged as a warning and skipped
    /// instead of aborting the run.
    pub allow_checksum_drift: bool,
}
29
30pub struct MigrationService<'a> {
31    db: &'a dyn DatabaseProvider,
32    config: MigrationConfig,
33}
34
35impl std::fmt::Debug for MigrationService<'_> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("MigrationService")
38            .field("config", &self.config)
39            .finish_non_exhaustive()
40    }
41}
42
43impl<'a> MigrationService<'a> {
44    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
45        Self {
46            db,
47            config: MigrationConfig::default(),
48        }
49    }
50
51    #[must_use]
52    pub const fn with_config(mut self, config: MigrationConfig) -> Self {
53        self.config = config;
54        self
55    }
56
57    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
58        let sql = include_str!("../../../schema/extension_migrations.sql");
59        SqlExecutor::execute_statements_parsed(self.db, sql)
60            .await
61            .map_err(|e| LoaderError::MigrationFailed {
62                extension: "database".to_string(),
63                message: format!("Failed to ensure migrations table exists: {e}"),
64            })
65    }
66
67    pub async fn get_applied_migrations(
68        &self,
69        extension_id: &str,
70    ) -> Result<Vec<AppliedMigration>, LoaderError> {
71        let result = self
72            .db
73            .query_raw_with(
74                &"SELECT extension_id, version, name, checksum, applied_at FROM \
75                  extension_migrations WHERE extension_id = $1 ORDER BY version",
76                &[&extension_id],
77            )
78            .await
79            .map_err(|e| LoaderError::MigrationFailed {
80                extension: extension_id.to_string(),
81                message: format!("Failed to query applied migrations: {e}"),
82            })?;
83
84        let migrations = result
85            .rows
86            .iter()
87            .filter_map(|row| {
88                Some(AppliedMigration {
89                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
90                    version: row.get("version")?.as_i64()? as u32,
91                    name: row.get("name")?.as_str()?.to_string(),
92                    checksum: row.get("checksum")?.as_str()?.to_string(),
93                    applied_at: row
94                        .get("applied_at")
95                        .and_then(|v| v.as_str().map(String::from)),
96                })
97            })
98            .collect();
99
100        Ok(migrations)
101    }
102
103    pub async fn run_pending_migrations(
104        &self,
105        extension: &dyn Extension,
106    ) -> Result<MigrationResult, LoaderError> {
107        let ext_id = extension.metadata().id;
108        let migrations = extension.migrations();
109
110        if migrations.is_empty() {
111            return Ok(MigrationResult::default());
112        }
113
114        self.ensure_migrations_table_exists().await?;
115
116        let applied = self.get_applied_migrations(ext_id).await?;
117        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
118        let applied_checksums: std::collections::HashMap<u32, &str> = applied
119            .iter()
120            .map(|m| (m.version, m.checksum.as_str()))
121            .collect();
122
123        let mut migrations_run = 0;
124        let mut migrations_skipped = 0;
125
126        for migration in &migrations {
127            if applied_versions.contains(&migration.version) {
128                self.verify_checksum(
129                    ext_id,
130                    migration,
131                    applied_checksums.get(&migration.version).copied(),
132                )?;
133                migrations_skipped += 1;
134                debug!(
135                    extension = %ext_id,
136                    version = migration.version,
137                    "Migration already applied, skipping"
138                );
139                continue;
140            }
141
142            self.execute_migration(extension, migration).await?;
143            migrations_run += 1;
144        }
145
146        if migrations_run > 0 {
147            info!(
148                extension = %ext_id,
149                migrations_run,
150                migrations_skipped,
151                "Migrations completed"
152            );
153        }
154
155        Ok(MigrationResult {
156            migrations_run,
157            migrations_skipped,
158        })
159    }
160
161    fn verify_checksum(
162        &self,
163        ext_id: &str,
164        migration: &Migration,
165        stored: Option<&str>,
166    ) -> Result<(), LoaderError> {
167        let Some(stored_checksum) = stored else {
168            return Ok(());
169        };
170        let current_checksum = migration.checksum();
171        if stored_checksum == current_checksum {
172            return Ok(());
173        }
174        if self.config.allow_checksum_drift {
175            warn!(
176                extension = %ext_id,
177                version = migration.version,
178                name = %migration.name,
179                stored_checksum = %stored_checksum,
180                current_checksum = %current_checksum,
181                "Migration checksum mismatch tolerated by --allow-checksum-drift"
182            );
183            return Ok(());
184        }
185        Err(LoaderError::MigrationFailed {
186            extension: ext_id.to_string(),
187            message: format!(
188                "Migration {ver} ('{name}') has been edited since it was applied (stored checksum \
189                 {stored_checksum}, current {current_checksum}). Refusing to proceed. Run \
190                 `systemprompt infra db migrate-repair --apply` to reconcile the tracking table, \
191                 or pass --allow-checksum-drift to bypass the check without fixing it.",
192                ver = migration.version,
193                name = migration.name,
194            ),
195        })
196    }
197
198    async fn execute_migration(
199        &self,
200        extension: &dyn Extension,
201        migration: &Migration,
202    ) -> Result<(), LoaderError> {
203        let ext_id = extension.metadata().id;
204
205        check_cross_extension_alters(extension, migration)?;
206
207        info!(
208            extension = %ext_id,
209            version = migration.version,
210            name = %migration.name,
211            no_transaction = migration.no_transaction,
212            "Running migration"
213        );
214
215        if migration.no_transaction {
216            SqlExecutor::execute_statements_parsed(self.db, migration.sql)
217                .await
218                .map_err(|e| LoaderError::MigrationFailed {
219                    extension: ext_id.to_string(),
220                    message: format!(
221                        "Failed to execute migration {} ({}): {e}",
222                        migration.version, migration.name
223                    ),
224                })?;
225        } else {
226            let statements = SqlExecutor::parse_sql_statements(migration.sql).map_err(|e| {
227                LoaderError::MigrationFailed {
228                    extension: ext_id.to_string(),
229                    message: format!(
230                        "Failed to parse migration {} ({}): {e}",
231                        migration.version, migration.name
232                    ),
233                }
234            })?;
235            execute_statements_transactional(self.db, &statements, ext_id, migration).await?;
236        }
237
238        self.record_migration(ext_id, migration).await?;
239
240        Ok(())
241    }
242
243    async fn record_migration(
244        &self,
245        ext_id: &str,
246        migration: &Migration,
247    ) -> Result<(), LoaderError> {
248        let id = format!("{}_{:03}", ext_id, migration.version);
249        let checksum = migration.checksum();
250
251        self.db
252            .execute(
253                &"INSERT INTO extension_migrations (id, extension_id, version, name, checksum) \
254                  VALUES ($1, $2, $3, $4, $5)",
255                &[&id, &ext_id, &migration.version, &migration.name, &checksum],
256            )
257            .await
258            .map_err(|e| LoaderError::MigrationFailed {
259                extension: ext_id.to_string(),
260                message: format!("Failed to record migration: {e}"),
261            })?;
262
263        Ok(())
264    }
265}