Skip to main content

systemprompt_database/lifecycle/migrations/
mod.rs

1//! Extension migration runner backed by the `extension_migrations`
2//! bookkeeping table. [`MigrationService`] applies, reverts, inspects, and
3//! squashes per-extension migration history; reverts live in [`down`],
4//! status/plan queries in [`status`], squash in [`squash`].
5
6mod down;
7mod exec;
8mod squash;
9mod status;
10
11pub use squash::SquashPlan;
12pub use status::{
13    AppliedMigration, ChecksumDrift, ExtensionMigrationStatus, MigrationResult, MigrationStatus,
14    PendingMigration,
15};
16
17use crate::services::{DatabaseProvider, SqlExecutor};
18use exec::{check_cross_extension_alters, execute_statements_transactional};
19use std::collections::HashSet;
20use systemprompt_extension::{Extension, LoaderError, Migration};
21use tracing::{debug, info, warn};
22
/// Tunables for [`MigrationService`].
///
/// The default leaves `allow_checksum_drift` off, so an already-applied
/// migration whose SQL has since been edited makes the run fail.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationConfig {
    /// When `true`, a checksum mismatch on an already-applied migration is
    /// logged as a warning instead of being returned as an error.
    pub allow_checksum_drift: bool,
}
27
28pub struct MigrationService<'a> {
29    db: &'a dyn DatabaseProvider,
30    config: MigrationConfig,
31}
32
33impl std::fmt::Debug for MigrationService<'_> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("MigrationService")
36            .field("config", &self.config)
37            .finish_non_exhaustive()
38    }
39}
40
41impl<'a> MigrationService<'a> {
42    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
43        Self {
44            db,
45            config: MigrationConfig::default(),
46        }
47    }
48
49    #[must_use]
50    pub const fn with_config(mut self, config: MigrationConfig) -> Self {
51        self.config = config;
52        self
53    }
54
55    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
56        let sql = include_str!("../../../schema/extension_migrations.sql");
57        SqlExecutor::execute_statements_parsed(self.db, sql)
58            .await
59            .map_err(|e| LoaderError::MigrationFailed {
60                extension: "database".to_string(),
61                message: format!("Failed to ensure migrations table exists: {e}"),
62            })
63    }
64
65    pub async fn get_applied_migrations(
66        &self,
67        extension_id: &str,
68    ) -> Result<Vec<AppliedMigration>, LoaderError> {
69        let result = self
70            .db
71            .query_raw_with(
72                &"SELECT extension_id, version, name, checksum, applied_at FROM \
73                  extension_migrations WHERE extension_id = $1 ORDER BY version",
74                &[&extension_id],
75            )
76            .await
77            .map_err(|e| LoaderError::MigrationFailed {
78                extension: extension_id.to_string(),
79                message: format!("Failed to query applied migrations: {e}"),
80            })?;
81
82        let migrations = result
83            .rows
84            .iter()
85            .filter_map(|row| {
86                Some(AppliedMigration {
87                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
88                    version: row.get("version")?.as_i64()? as u32,
89                    name: row.get("name")?.as_str()?.to_string(),
90                    checksum: row.get("checksum")?.as_str()?.to_string(),
91                    applied_at: row
92                        .get("applied_at")
93                        .and_then(|v| v.as_str().map(String::from)),
94                })
95            })
96            .collect();
97
98        Ok(migrations)
99    }
100
101    pub async fn run_pending_migrations(
102        &self,
103        extension: &dyn Extension,
104    ) -> Result<MigrationResult, LoaderError> {
105        let ext_id = extension.metadata().id;
106        let migrations = extension.migrations();
107
108        if migrations.is_empty() {
109            return Ok(MigrationResult::default());
110        }
111
112        self.ensure_migrations_table_exists().await?;
113
114        let applied = self.get_applied_migrations(ext_id).await?;
115        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
116        let applied_checksums: std::collections::HashMap<u32, &str> = applied
117            .iter()
118            .map(|m| (m.version, m.checksum.as_str()))
119            .collect();
120
121        let mut migrations_run = 0;
122        let mut migrations_skipped = 0;
123
124        for migration in &migrations {
125            if applied_versions.contains(&migration.version) {
126                self.verify_checksum(
127                    ext_id,
128                    migration,
129                    applied_checksums.get(&migration.version).copied(),
130                )?;
131                migrations_skipped += 1;
132                debug!(
133                    extension = %ext_id,
134                    version = migration.version,
135                    "Migration already applied, skipping"
136                );
137                continue;
138            }
139
140            self.execute_migration(extension, migration).await?;
141            migrations_run += 1;
142        }
143
144        if migrations_run > 0 {
145            info!(
146                extension = %ext_id,
147                migrations_run,
148                migrations_skipped,
149                "Migrations completed"
150            );
151        }
152
153        Ok(MigrationResult {
154            migrations_run,
155            migrations_skipped,
156        })
157    }
158
159    fn verify_checksum(
160        &self,
161        ext_id: &str,
162        migration: &Migration,
163        stored: Option<&str>,
164    ) -> Result<(), LoaderError> {
165        let Some(stored_checksum) = stored else {
166            return Ok(());
167        };
168        let current_checksum = migration.checksum();
169        if stored_checksum == current_checksum {
170            return Ok(());
171        }
172        if self.config.allow_checksum_drift {
173            warn!(
174                extension = %ext_id,
175                version = migration.version,
176                name = %migration.name,
177                stored_checksum = %stored_checksum,
178                current_checksum = %current_checksum,
179                "Migration checksum mismatch tolerated by --allow-checksum-drift"
180            );
181            return Ok(());
182        }
183        Err(LoaderError::MigrationFailed {
184            extension: ext_id.to_string(),
185            message: format!(
186                "Migration {ver} ('{name}') has been edited since it was applied (stored checksum \
187                 {stored_checksum}, current {current_checksum}). Refusing to proceed. Re-run with \
188                 --allow-checksum-drift to override.",
189                ver = migration.version,
190                name = migration.name,
191            ),
192        })
193    }
194
195    async fn execute_migration(
196        &self,
197        extension: &dyn Extension,
198        migration: &Migration,
199    ) -> Result<(), LoaderError> {
200        let ext_id = extension.metadata().id;
201
202        check_cross_extension_alters(extension, migration)?;
203
204        info!(
205            extension = %ext_id,
206            version = migration.version,
207            name = %migration.name,
208            no_transaction = migration.no_transaction,
209            "Running migration"
210        );
211
212        if migration.no_transaction {
213            SqlExecutor::execute_statements_parsed(self.db, migration.sql)
214                .await
215                .map_err(|e| LoaderError::MigrationFailed {
216                    extension: ext_id.to_string(),
217                    message: format!(
218                        "Failed to execute migration {} ({}): {e}",
219                        migration.version, migration.name
220                    ),
221                })?;
222        } else {
223            let statements = SqlExecutor::parse_sql_statements(migration.sql).map_err(|e| {
224                LoaderError::MigrationFailed {
225                    extension: ext_id.to_string(),
226                    message: format!(
227                        "Failed to parse migration {} ({}): {e}",
228                        migration.version, migration.name
229                    ),
230                }
231            })?;
232            execute_statements_transactional(self.db, &statements, ext_id, migration).await?;
233        }
234
235        self.record_migration(ext_id, migration).await?;
236
237        Ok(())
238    }
239
240    async fn record_migration(
241        &self,
242        ext_id: &str,
243        migration: &Migration,
244    ) -> Result<(), LoaderError> {
245        let id = format!("{}_{:03}", ext_id, migration.version);
246        let checksum = migration.checksum();
247
248        self.db
249            .execute(
250                &"INSERT INTO extension_migrations (id, extension_id, version, name, checksum) \
251                  VALUES ($1, $2, $3, $4, $5)",
252                &[&id, &ext_id, &migration.version, &migration.name, &checksum],
253            )
254            .await
255            .map_err(|e| LoaderError::MigrationFailed {
256                extension: ext_id.to_string(),
257                message: format!("Failed to record migration: {e}"),
258            })?;
259
260        Ok(())
261    }
262}