Skip to main content

systemprompt_database/lifecycle/
migrations.rs

1use crate::services::{DatabaseProvider, SqlExecutor};
2use std::collections::HashSet;
3use systemprompt_extension::{Extension, LoaderError, Migration};
4use tracing::{debug, info, warn};
5
/// A migration that has already been recorded in the `extension_migrations` table.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Value of the `extension_id` column: the extension that owns this migration.
    pub extension_id: String,
    /// Value of the `version` column: the migration's version within its extension.
    pub version: u32,
    /// Value of the `name` column: the migration name stored when it was applied.
    pub name: String,
    /// Value of the `checksum` column. It is compared against the current
    /// migration's checksum to detect SQL edited after it was applied.
    pub checksum: String,
}
13
/// Counters reported by one `run_pending_migrations` call.
///
/// `Default` yields zero for both counters, which is the result reported for
/// an extension that defines no migrations.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// How many migrations were executed and recorded during the call.
    pub migrations_run: usize,
    /// How many migrations were skipped because their version was already recorded.
    pub migrations_skipped: usize,
}
19
20pub struct MigrationService<'a> {
21    db: &'a dyn DatabaseProvider,
22}
23
24impl std::fmt::Debug for MigrationService<'_> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("MigrationService").finish_non_exhaustive()
27    }
28}
29
30impl<'a> MigrationService<'a> {
31    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
32        Self { db }
33    }
34
35    pub async fn get_applied_migrations(
36        &self,
37        extension_id: &str,
38    ) -> Result<Vec<AppliedMigration>, LoaderError> {
39        let result = self
40            .db
41            .query_raw_with(
42                &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
43                  extension_id = $1 ORDER BY version",
44                vec![serde_json::Value::String(extension_id.to_string())],
45            )
46            .await
47            .map_err(|e| LoaderError::MigrationFailed {
48                extension: extension_id.to_string(),
49                message: format!("Failed to query applied migrations: {e}"),
50            })?;
51
52        let migrations = result
53            .rows
54            .iter()
55            .filter_map(|row| {
56                Some(AppliedMigration {
57                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
58                    version: row.get("version")?.as_i64()? as u32,
59                    name: row.get("name")?.as_str()?.to_string(),
60                    checksum: row.get("checksum")?.as_str()?.to_string(),
61                })
62            })
63            .collect();
64
65        Ok(migrations)
66    }
67
68    pub async fn run_pending_migrations(
69        &self,
70        extension: &dyn Extension,
71    ) -> Result<MigrationResult, LoaderError> {
72        let ext_id = extension.metadata().id;
73        let migrations = extension.migrations();
74
75        if migrations.is_empty() {
76            return Ok(MigrationResult::default());
77        }
78
79        let applied = self.get_applied_migrations(ext_id).await?;
80        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
81        let applied_checksums: std::collections::HashMap<u32, &str> = applied
82            .iter()
83            .map(|m| (m.version, m.checksum.as_str()))
84            .collect();
85
86        let mut migrations_run = 0;
87        let mut migrations_skipped = 0;
88
89        for migration in &migrations {
90            if applied_versions.contains(&migration.version) {
91                let current_checksum = migration.checksum();
92                if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
93                    if stored_checksum != current_checksum {
94                        warn!(
95                            extension = %ext_id,
96                            version = migration.version,
97                            name = %migration.name,
98                            stored_checksum = %stored_checksum,
99                            current_checksum = %current_checksum,
100                            "Migration checksum mismatch - SQL has changed since it was applied"
101                        );
102                    }
103                }
104                migrations_skipped += 1;
105                debug!(
106                    extension = %ext_id,
107                    version = migration.version,
108                    "Migration already applied, skipping"
109                );
110                continue;
111            }
112
113            self.execute_migration(ext_id, migration).await?;
114            migrations_run += 1;
115        }
116
117        if migrations_run > 0 {
118            info!(
119                extension = %ext_id,
120                migrations_run,
121                migrations_skipped,
122                "Migrations completed"
123            );
124        }
125
126        Ok(MigrationResult {
127            migrations_run,
128            migrations_skipped,
129        })
130    }
131
132    async fn execute_migration(
133        &self,
134        ext_id: &str,
135        migration: &Migration,
136    ) -> Result<(), LoaderError> {
137        info!(
138            extension = %ext_id,
139            version = migration.version,
140            name = %migration.name,
141            "Running migration"
142        );
143
144        SqlExecutor::execute_statements_parsed(self.db, migration.sql)
145            .await
146            .map_err(|e| LoaderError::MigrationFailed {
147                extension: ext_id.to_string(),
148                message: format!(
149                    "Failed to execute migration {} ({}): {e}",
150                    migration.version, migration.name
151                ),
152            })?;
153
154        self.record_migration(ext_id, migration).await?;
155
156        Ok(())
157    }
158
159    async fn record_migration(
160        &self,
161        ext_id: &str,
162        migration: &Migration,
163    ) -> Result<(), LoaderError> {
164        let id = format!("{}_{:03}", ext_id, migration.version);
165        let checksum = migration.checksum();
166        let name = migration.name.replace('\'', "''");
167
168        let sql = format!(
169            "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
170             ('{}', '{}', {}, '{}', '{}')",
171            id, ext_id, migration.version, name, checksum
172        );
173
174        self.db
175            .execute_raw(&sql)
176            .await
177            .map_err(|e| LoaderError::MigrationFailed {
178                extension: ext_id.to_string(),
179                message: format!("Failed to record migration: {e}"),
180            })?;
181
182        Ok(())
183    }
184
185    pub async fn get_migration_status(
186        &self,
187        extension: &dyn Extension,
188    ) -> Result<MigrationStatus, LoaderError> {
189        let ext_id = extension.metadata().id;
190        let defined_migrations = extension.migrations();
191        let applied = self.get_applied_migrations(ext_id).await?;
192
193        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
194
195        let pending: Vec<_> = defined_migrations
196            .iter()
197            .filter(|m| !applied_versions.contains(&m.version))
198            .cloned()
199            .collect();
200
201        Ok(MigrationStatus {
202            extension_id: ext_id.to_string(),
203            total_defined: defined_migrations.len(),
204            total_applied: applied.len(),
205            pending_count: pending.len(),
206            pending,
207            applied,
208        })
209    }
210}
211
212#[derive(Debug)]
213pub struct MigrationStatus {
214    pub extension_id: String,
215    pub total_defined: usize,
216    pub total_applied: usize,
217    pub pending_count: usize,
218    pub pending: Vec<Migration>,
219    pub applied: Vec<AppliedMigration>,
220}