Skip to main content

systemprompt_database/lifecycle/
migrations.rs

1//! Extension migration runner backed by the `extension_migrations`
2//! bookkeeping table.
3
4use crate::services::{DatabaseProvider, SqlExecutor};
5use std::collections::HashSet;
6use systemprompt_extension::{Extension, LoaderError, Migration};
7use tracing::{debug, info, warn};
8
/// One row read back from the `extension_migrations` bookkeeping table.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Identifier of the extension that owns the migration.
    pub extension_id: String,
    /// Version number of the migration within its extension.
    pub version: u32,
    /// Name stored alongside the migration when it was recorded.
    pub name: String,
    /// Checksum stored when the migration was recorded; compared against the
    /// current checksum to detect edits made after the fact.
    pub checksum: String,
}
16
/// Counters describing the outcome of one migration pass for an extension.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// Number of migrations executed during the pass.
    pub migrations_run: usize,
    /// Number of migrations skipped because they were already recorded.
    pub migrations_skipped: usize,
}
22
/// Tunables for [`MigrationService`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationConfig {
    /// When `true`, a checksum mismatch on an already-applied migration is
    /// logged as a warning instead of aborting the run. Defaults to `false`.
    pub allow_checksum_drift: bool,
}
27
28pub struct MigrationService<'a> {
29    db: &'a dyn DatabaseProvider,
30    config: MigrationConfig,
31}
32
33impl std::fmt::Debug for MigrationService<'_> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("MigrationService")
36            .field("config", &self.config)
37            .finish_non_exhaustive()
38    }
39}
40
41impl<'a> MigrationService<'a> {
42    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
43        Self {
44            db,
45            config: MigrationConfig::default(),
46        }
47    }
48
49    #[must_use]
50    pub const fn with_config(mut self, config: MigrationConfig) -> Self {
51        self.config = config;
52        self
53    }
54
55    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
56        let sql = include_str!("../../schema/extension_migrations.sql");
57        SqlExecutor::execute_statements_parsed(self.db, sql)
58            .await
59            .map_err(|e| LoaderError::MigrationFailed {
60                extension: "database".to_string(),
61                message: format!("Failed to ensure migrations table exists: {e}"),
62            })
63    }
64
65    pub async fn get_applied_migrations(
66        &self,
67        extension_id: &str,
68    ) -> Result<Vec<AppliedMigration>, LoaderError> {
69        let result = self
70            .db
71            .query_raw_with(
72                &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
73                  extension_id = $1 ORDER BY version",
74                vec![serde_json::Value::String(extension_id.to_string())],
75            )
76            .await
77            .map_err(|e| LoaderError::MigrationFailed {
78                extension: extension_id.to_string(),
79                message: format!("Failed to query applied migrations: {e}"),
80            })?;
81
82        let migrations = result
83            .rows
84            .iter()
85            .filter_map(|row| {
86                Some(AppliedMigration {
87                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
88                    version: row.get("version")?.as_i64()? as u32,
89                    name: row.get("name")?.as_str()?.to_string(),
90                    checksum: row.get("checksum")?.as_str()?.to_string(),
91                })
92            })
93            .collect();
94
95        Ok(migrations)
96    }
97
98    pub async fn run_pending_migrations(
99        &self,
100        extension: &dyn Extension,
101    ) -> Result<MigrationResult, LoaderError> {
102        let ext_id = extension.metadata().id;
103        let migrations = extension.migrations();
104
105        if migrations.is_empty() {
106            return Ok(MigrationResult::default());
107        }
108
109        self.ensure_migrations_table_exists().await?;
110
111        let applied = self.get_applied_migrations(ext_id).await?;
112        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
113        let applied_checksums: std::collections::HashMap<u32, &str> = applied
114            .iter()
115            .map(|m| (m.version, m.checksum.as_str()))
116            .collect();
117
118        let mut migrations_run = 0;
119        let mut migrations_skipped = 0;
120
121        for migration in &migrations {
122            if applied_versions.contains(&migration.version) {
123                let current_checksum = migration.checksum();
124                if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
125                    if stored_checksum != current_checksum {
126                        if self.config.allow_checksum_drift {
127                            warn!(
128                                extension = %ext_id,
129                                version = migration.version,
130                                name = %migration.name,
131                                stored_checksum = %stored_checksum,
132                                current_checksum = %current_checksum,
133                                "Migration checksum mismatch tolerated by --allow-checksum-drift"
134                            );
135                        } else {
136                            return Err(LoaderError::MigrationFailed {
137                                extension: ext_id.to_string(),
138                                message: format!(
139                                    "Migration {ver} ('{name}') has been edited since it was \
140                                     applied (stored checksum {stored_checksum}, current \
141                                     {current_checksum}). Refusing to proceed. Re-run with \
142                                     --allow-checksum-drift to override.",
143                                    ver = migration.version,
144                                    name = migration.name,
145                                ),
146                            });
147                        }
148                    }
149                }
150                migrations_skipped += 1;
151                debug!(
152                    extension = %ext_id,
153                    version = migration.version,
154                    "Migration already applied, skipping"
155                );
156                continue;
157            }
158
159            self.execute_migration(ext_id, migration).await?;
160            migrations_run += 1;
161        }
162
163        if migrations_run > 0 {
164            info!(
165                extension = %ext_id,
166                migrations_run,
167                migrations_skipped,
168                "Migrations completed"
169            );
170        }
171
172        Ok(MigrationResult {
173            migrations_run,
174            migrations_skipped,
175        })
176    }
177
178    async fn execute_migration(
179        &self,
180        ext_id: &str,
181        migration: &Migration,
182    ) -> Result<(), LoaderError> {
183        info!(
184            extension = %ext_id,
185            version = migration.version,
186            name = %migration.name,
187            "Running migration"
188        );
189
190        SqlExecutor::execute_statements_parsed(self.db, migration.sql)
191            .await
192            .map_err(|e| LoaderError::MigrationFailed {
193                extension: ext_id.to_string(),
194                message: format!(
195                    "Failed to execute migration {} ({}): {e}",
196                    migration.version, migration.name
197                ),
198            })?;
199
200        self.record_migration(ext_id, migration).await?;
201
202        Ok(())
203    }
204
205    async fn record_migration(
206        &self,
207        ext_id: &str,
208        migration: &Migration,
209    ) -> Result<(), LoaderError> {
210        let id = format!("{}_{:03}", ext_id, migration.version);
211        let checksum = migration.checksum();
212        let name = migration.name.replace('\'', "''");
213
214        let sql = format!(
215            "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
216             ('{}', '{}', {}, '{}', '{}')",
217            id, ext_id, migration.version, name, checksum
218        );
219
220        self.db
221            .execute_raw(&sql)
222            .await
223            .map_err(|e| LoaderError::MigrationFailed {
224                extension: ext_id.to_string(),
225                message: format!("Failed to record migration: {e}"),
226            })?;
227
228        Ok(())
229    }
230
231    pub async fn get_migration_status(
232        &self,
233        extension: &dyn Extension,
234    ) -> Result<MigrationStatus, LoaderError> {
235        self.ensure_migrations_table_exists().await?;
236
237        let ext_id = extension.metadata().id;
238        let defined_migrations = extension.migrations();
239        let applied = self.get_applied_migrations(ext_id).await?;
240
241        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
242
243        let pending: Vec<_> = defined_migrations
244            .iter()
245            .filter(|m| !applied_versions.contains(&m.version))
246            .cloned()
247            .collect();
248
249        Ok(MigrationStatus {
250            extension_id: ext_id.to_string(),
251            total_defined: defined_migrations.len(),
252            total_applied: applied.len(),
253            pending_count: pending.len(),
254            pending,
255            applied,
256        })
257    }
258}
259
260#[derive(Debug)]
261pub struct MigrationStatus {
262    pub extension_id: String,
263    pub total_defined: usize,
264    pub total_applied: usize,
265    pub pending_count: usize,
266    pub pending: Vec<Migration>,
267    pub applied: Vec<AppliedMigration>,
268}