Skip to main content

systemprompt_database/lifecycle/
migrations.rs

1//! Extension migration runner backed by the `extension_migrations`
2//! bookkeeping table.
3
4use crate::services::{DatabaseProvider, SqlExecutor};
5use std::collections::HashSet;
6use systemprompt_extension::{Extension, LoaderError, Migration};
7use tracing::{debug, info, warn};
8
/// A row of the `extension_migrations` bookkeeping table: one migration that
/// has already been applied for an extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppliedMigration {
    /// Identifier of the extension the migration belongs to.
    pub extension_id: String,
    /// Migration version number, unique per extension.
    pub version: u32,
    /// Human-readable migration name as recorded when it was applied.
    pub name: String,
    /// Checksum of the migration SQL captured at apply time; compared against
    /// the current checksum to detect migrations edited after being applied.
    pub checksum: String,
}
16
/// Outcome of [`MigrationService::run_pending_migrations`] for one extension.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MigrationResult {
    /// Number of migrations executed during this run.
    pub migrations_run: usize,
    /// Number of migrations skipped because they were already recorded as applied.
    pub migrations_skipped: usize,
}
22
23pub struct MigrationService<'a> {
24    db: &'a dyn DatabaseProvider,
25}
26
27impl std::fmt::Debug for MigrationService<'_> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("MigrationService").finish_non_exhaustive()
30    }
31}
32
33impl<'a> MigrationService<'a> {
34    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
35        Self { db }
36    }
37
38    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
39        let sql = include_str!("../../schema/extension_migrations.sql");
40        SqlExecutor::execute_statements_parsed(self.db, sql)
41            .await
42            .map_err(|e| LoaderError::MigrationFailed {
43                extension: "database".to_string(),
44                message: format!("Failed to ensure migrations table exists: {e}"),
45            })
46    }
47
48    pub async fn get_applied_migrations(
49        &self,
50        extension_id: &str,
51    ) -> Result<Vec<AppliedMigration>, LoaderError> {
52        let result = self
53            .db
54            .query_raw_with(
55                &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
56                  extension_id = $1 ORDER BY version",
57                vec![serde_json::Value::String(extension_id.to_string())],
58            )
59            .await
60            .map_err(|e| LoaderError::MigrationFailed {
61                extension: extension_id.to_string(),
62                message: format!("Failed to query applied migrations: {e}"),
63            })?;
64
65        let migrations = result
66            .rows
67            .iter()
68            .filter_map(|row| {
69                Some(AppliedMigration {
70                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
71                    version: row.get("version")?.as_i64()? as u32,
72                    name: row.get("name")?.as_str()?.to_string(),
73                    checksum: row.get("checksum")?.as_str()?.to_string(),
74                })
75            })
76            .collect();
77
78        Ok(migrations)
79    }
80
81    pub async fn run_pending_migrations(
82        &self,
83        extension: &dyn Extension,
84    ) -> Result<MigrationResult, LoaderError> {
85        let ext_id = extension.metadata().id;
86        let migrations = extension.migrations();
87
88        if migrations.is_empty() {
89            return Ok(MigrationResult::default());
90        }
91
92        self.ensure_migrations_table_exists().await?;
93
94        let applied = self.get_applied_migrations(ext_id).await?;
95        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
96        let applied_checksums: std::collections::HashMap<u32, &str> = applied
97            .iter()
98            .map(|m| (m.version, m.checksum.as_str()))
99            .collect();
100
101        let mut migrations_run = 0;
102        let mut migrations_skipped = 0;
103
104        for migration in &migrations {
105            if applied_versions.contains(&migration.version) {
106                let current_checksum = migration.checksum();
107                if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
108                    if stored_checksum != current_checksum {
109                        warn!(
110                            extension = %ext_id,
111                            version = migration.version,
112                            name = %migration.name,
113                            stored_checksum = %stored_checksum,
114                            current_checksum = %current_checksum,
115                            "Migration checksum mismatch - SQL has changed since it was applied"
116                        );
117                    }
118                }
119                migrations_skipped += 1;
120                debug!(
121                    extension = %ext_id,
122                    version = migration.version,
123                    "Migration already applied, skipping"
124                );
125                continue;
126            }
127
128            self.execute_migration(ext_id, migration).await?;
129            migrations_run += 1;
130        }
131
132        if migrations_run > 0 {
133            info!(
134                extension = %ext_id,
135                migrations_run,
136                migrations_skipped,
137                "Migrations completed"
138            );
139        }
140
141        Ok(MigrationResult {
142            migrations_run,
143            migrations_skipped,
144        })
145    }
146
147    async fn execute_migration(
148        &self,
149        ext_id: &str,
150        migration: &Migration,
151    ) -> Result<(), LoaderError> {
152        info!(
153            extension = %ext_id,
154            version = migration.version,
155            name = %migration.name,
156            "Running migration"
157        );
158
159        SqlExecutor::execute_statements_parsed(self.db, migration.sql)
160            .await
161            .map_err(|e| LoaderError::MigrationFailed {
162                extension: ext_id.to_string(),
163                message: format!(
164                    "Failed to execute migration {} ({}): {e}",
165                    migration.version, migration.name
166                ),
167            })?;
168
169        self.record_migration(ext_id, migration).await?;
170
171        Ok(())
172    }
173
174    async fn record_migration(
175        &self,
176        ext_id: &str,
177        migration: &Migration,
178    ) -> Result<(), LoaderError> {
179        let id = format!("{}_{:03}", ext_id, migration.version);
180        let checksum = migration.checksum();
181        let name = migration.name.replace('\'', "''");
182
183        let sql = format!(
184            "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
185             ('{}', '{}', {}, '{}', '{}')",
186            id, ext_id, migration.version, name, checksum
187        );
188
189        self.db
190            .execute_raw(&sql)
191            .await
192            .map_err(|e| LoaderError::MigrationFailed {
193                extension: ext_id.to_string(),
194                message: format!("Failed to record migration: {e}"),
195            })?;
196
197        Ok(())
198    }
199
200    pub async fn get_migration_status(
201        &self,
202        extension: &dyn Extension,
203    ) -> Result<MigrationStatus, LoaderError> {
204        self.ensure_migrations_table_exists().await?;
205
206        let ext_id = extension.metadata().id;
207        let defined_migrations = extension.migrations();
208        let applied = self.get_applied_migrations(ext_id).await?;
209
210        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
211
212        let pending: Vec<_> = defined_migrations
213            .iter()
214            .filter(|m| !applied_versions.contains(&m.version))
215            .cloned()
216            .collect();
217
218        Ok(MigrationStatus {
219            extension_id: ext_id.to_string(),
220            total_defined: defined_migrations.len(),
221            total_applied: applied.len(),
222            pending_count: pending.len(),
223            pending,
224            applied,
225        })
226    }
227}
228
229#[derive(Debug)]
230pub struct MigrationStatus {
231    pub extension_id: String,
232    pub total_defined: usize,
233    pub total_applied: usize,
234    pub pending_count: usize,
235    pub pending: Vec<Migration>,
236    pub applied: Vec<AppliedMigration>,
237}