Skip to main content

systemprompt_database/lifecycle/migrations/
mod.rs

1//! Extension migration runner backed by the `extension_migrations`
2//! bookkeeping table. [`MigrationService`] applies, reverts, inspects, and
3//! squashes per-extension migration history; reverts live in [`down`],
4//! status/plan queries in [`status`], squash in [`squash`].
5
6mod down;
7mod exec;
8mod mark_applied;
9mod repair;
10mod squash;
11mod status;
12
13pub use mark_applied::MarkAppliedOutcome;
14pub use repair::RepairResult;
15pub use squash::SquashPlan;
16pub use status::{
17    AppliedMigration, ChecksumDrift, ExtensionMigrationStatus, MigrationResult, MigrationStatus,
18    PendingMigration,
19};
20
21use crate::services::{DatabaseProvider, SqlExecutor};
22use exec::{check_cross_extension_alters, execute_statements_transactional};
23use std::collections::HashSet;
24use systemprompt_extension::{Extension, LoaderError, Migration};
25use tracing::{debug, info, warn};
26
/// Options controlling how [`MigrationService`] treats already-recorded
/// migrations.
///
/// The default configuration rejects checksum drift.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationConfig {
    /// When `true`, an applied migration whose current checksum differs from
    /// the stored one is logged as a warning instead of aborting the run.
    pub allow_checksum_drift: bool,
}
31
32pub struct MigrationService<'a> {
33    db: &'a dyn DatabaseProvider,
34    config: MigrationConfig,
35}
36
37impl std::fmt::Debug for MigrationService<'_> {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("MigrationService")
40            .field("config", &self.config)
41            .finish_non_exhaustive()
42    }
43}
44
45impl<'a> MigrationService<'a> {
46    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
47        Self {
48            db,
49            config: MigrationConfig::default(),
50        }
51    }
52
53    #[must_use]
54    pub const fn with_config(mut self, config: MigrationConfig) -> Self {
55        self.config = config;
56        self
57    }
58
59    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
60        let sql = include_str!("../../../schema/extension_migrations.sql");
61        SqlExecutor::execute_statements_parsed(self.db, sql)
62            .await
63            .map_err(|e| LoaderError::MigrationFailed {
64                extension: "database".to_owned(),
65                message: format!("Failed to ensure migrations table exists: {e}"),
66            })
67    }
68
69    pub async fn get_applied_migrations(
70        &self,
71        extension_id: &str,
72    ) -> Result<Vec<AppliedMigration>, LoaderError> {
73        let result = self
74            .db
75            .query_raw_with(
76                &"SELECT extension_id, version, name, checksum, applied_at FROM \
77                  extension_migrations WHERE extension_id = $1 ORDER BY version",
78                &[&extension_id],
79            )
80            .await
81            .map_err(|e| LoaderError::MigrationFailed {
82                extension: extension_id.to_owned(),
83                message: format!("Failed to query applied migrations: {e}"),
84            })?;
85
86        let migrations = result
87            .rows
88            .iter()
89            .filter_map(|row| {
90                Some(AppliedMigration {
91                    extension_id: row.get("extension_id")?.as_str()?.to_owned(),
92                    version: row.get("version")?.as_i64()? as u32,
93                    name: row.get("name")?.as_str()?.to_owned(),
94                    checksum: row.get("checksum")?.as_str()?.to_owned(),
95                    applied_at: row
96                        .get("applied_at")
97                        .and_then(|v| v.as_str().map(String::from)),
98                })
99            })
100            .collect();
101
102        Ok(migrations)
103    }
104
105    pub async fn run_pending_migrations(
106        &self,
107        extension: &dyn Extension,
108    ) -> Result<MigrationResult, LoaderError> {
109        let ext_id = extension.metadata().id;
110        let migrations = extension.migrations();
111
112        if migrations.is_empty() {
113            return Ok(MigrationResult::default());
114        }
115
116        self.ensure_migrations_table_exists().await?;
117
118        let applied = self.get_applied_migrations(ext_id).await?;
119        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
120        let applied_checksums: std::collections::HashMap<u32, &str> = applied
121            .iter()
122            .map(|m| (m.version, m.checksum.as_str()))
123            .collect();
124
125        let mut migrations_run = 0;
126        let mut migrations_skipped = 0;
127
128        for migration in &migrations {
129            if applied_versions.contains(&migration.version) {
130                self.verify_checksum(
131                    ext_id,
132                    migration,
133                    applied_checksums.get(&migration.version).copied(),
134                )?;
135                migrations_skipped += 1;
136                debug!(
137                    extension = %ext_id,
138                    version = migration.version,
139                    "Migration already applied, skipping"
140                );
141                continue;
142            }
143
144            self.execute_migration(extension, migration).await?;
145            migrations_run += 1;
146        }
147
148        if migrations_run > 0 {
149            info!(
150                extension = %ext_id,
151                migrations_run,
152                migrations_skipped,
153                "Migrations completed"
154            );
155        }
156
157        Ok(MigrationResult {
158            migrations_run,
159            migrations_skipped,
160        })
161    }
162
163    fn verify_checksum(
164        &self,
165        ext_id: &str,
166        migration: &Migration,
167        stored: Option<&str>,
168    ) -> Result<(), LoaderError> {
169        let Some(stored_checksum) = stored else {
170            return Ok(());
171        };
172        let current_checksum = migration.checksum();
173        if stored_checksum == current_checksum {
174            return Ok(());
175        }
176        if self.config.allow_checksum_drift {
177            warn!(
178                extension = %ext_id,
179                version = migration.version,
180                name = %migration.name,
181                stored_checksum = %stored_checksum,
182                current_checksum = %current_checksum,
183                "Migration checksum mismatch tolerated by --allow-checksum-drift"
184            );
185            return Ok(());
186        }
187        Err(LoaderError::MigrationFailed {
188            extension: ext_id.to_owned(),
189            message: format!(
190                "Migration {ver} ('{name}') has been edited since it was applied (stored checksum \
191                 {stored_checksum}, current {current_checksum}). Refusing to proceed. Run \
192                 `systemprompt infra db migrate-repair --apply` to reconcile the tracking table, \
193                 or pass --allow-checksum-drift to bypass the check without fixing it.",
194                ver = migration.version,
195                name = migration.name,
196            ),
197        })
198    }
199
200    async fn execute_migration(
201        &self,
202        extension: &dyn Extension,
203        migration: &Migration,
204    ) -> Result<(), LoaderError> {
205        let ext_id = extension.metadata().id;
206
207        check_cross_extension_alters(extension, migration)?;
208
209        info!(
210            extension = %ext_id,
211            version = migration.version,
212            name = %migration.name,
213            no_transaction = migration.no_transaction,
214            "Running migration"
215        );
216
217        if migration.no_transaction {
218            SqlExecutor::execute_statements_parsed(self.db, migration.sql)
219                .await
220                .map_err(|e| LoaderError::MigrationFailed {
221                    extension: ext_id.to_owned(),
222                    message: format!(
223                        "Failed to execute migration {} ({}): {e}",
224                        migration.version, migration.name
225                    ),
226                })?;
227        } else {
228            let statements = SqlExecutor::parse_sql_statements(migration.sql).map_err(|e| {
229                LoaderError::MigrationFailed {
230                    extension: ext_id.to_owned(),
231                    message: format!(
232                        "Failed to parse migration {} ({}): {e}",
233                        migration.version, migration.name
234                    ),
235                }
236            })?;
237            execute_statements_transactional(self.db, &statements, ext_id, migration).await?;
238        }
239
240        self.record_migration(ext_id, migration).await?;
241
242        Ok(())
243    }
244
245    async fn record_migration(
246        &self,
247        ext_id: &str,
248        migration: &Migration,
249    ) -> Result<(), LoaderError> {
250        let id = format!("{}_{:03}", ext_id, migration.version);
251        let checksum = migration.checksum();
252
253        self.db
254            .execute(
255                &"INSERT INTO extension_migrations (id, extension_id, version, name, checksum) \
256                  VALUES ($1, $2, $3, $4, $5)",
257                &[&id, &ext_id, &migration.version, &migration.name, &checksum],
258            )
259            .await
260            .map_err(|e| LoaderError::MigrationFailed {
261                extension: ext_id.to_owned(),
262                message: format!("Failed to record migration: {e}"),
263            })?;
264
265        Ok(())
266    }
267}