systemprompt_database/lifecycle/
migrations.rs1use crate::services::{DatabaseProvider, SqlExecutor};
2use std::collections::HashSet;
3use systemprompt_extension::{Extension, LoaderError, Migration};
4use tracing::{debug, info, warn};
5
/// A migration that has been recorded as applied in the `extension_migrations` table.
///
/// Equality compares every field, so two records are equal only when extension,
/// version, name and checksum all match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppliedMigration {
    /// Id of the extension the migration belongs to.
    pub extension_id: String,
    /// Version number of the migration.
    pub version: u32,
    /// Migration name as it was recorded.
    pub name: String,
    /// Checksum stored when the migration was recorded; compared against the
    /// current migration checksum to detect SQL drift.
    pub checksum: String,
}
13
/// Counts produced by a single `run_pending_migrations` call.
///
/// `Default` yields zero for both counters, which is what is returned when an
/// extension defines no migrations.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MigrationResult {
    /// Number of migrations executed and recorded during this call.
    pub migrations_run: usize,
    /// Number of defined migrations skipped because they were already recorded.
    pub migrations_skipped: usize,
}
19
20pub struct MigrationService<'a> {
21 db: &'a dyn DatabaseProvider,
22}
23
24impl std::fmt::Debug for MigrationService<'_> {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("MigrationService").finish_non_exhaustive()
27 }
28}
29
30impl<'a> MigrationService<'a> {
31 pub fn new(db: &'a dyn DatabaseProvider) -> Self {
32 Self { db }
33 }
34
35 pub async fn get_applied_migrations(
36 &self,
37 extension_id: &str,
38 ) -> Result<Vec<AppliedMigration>, LoaderError> {
39 let result = self
40 .db
41 .query_raw_with(
42 &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
43 extension_id = $1 ORDER BY version",
44 vec![serde_json::Value::String(extension_id.to_string())],
45 )
46 .await
47 .map_err(|e| LoaderError::MigrationFailed {
48 extension: extension_id.to_string(),
49 message: format!("Failed to query applied migrations: {e}"),
50 })?;
51
52 let migrations = result
53 .rows
54 .iter()
55 .filter_map(|row| {
56 Some(AppliedMigration {
57 extension_id: row.get("extension_id")?.as_str()?.to_string(),
58 version: row.get("version")?.as_i64()? as u32,
59 name: row.get("name")?.as_str()?.to_string(),
60 checksum: row.get("checksum")?.as_str()?.to_string(),
61 })
62 })
63 .collect();
64
65 Ok(migrations)
66 }
67
68 pub async fn run_pending_migrations(
69 &self,
70 extension: &dyn Extension,
71 ) -> Result<MigrationResult, LoaderError> {
72 let ext_id = extension.metadata().id;
73 let migrations = extension.migrations();
74
75 if migrations.is_empty() {
76 return Ok(MigrationResult::default());
77 }
78
79 let applied = self.get_applied_migrations(ext_id).await?;
80 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
81 let applied_checksums: std::collections::HashMap<u32, &str> = applied
82 .iter()
83 .map(|m| (m.version, m.checksum.as_str()))
84 .collect();
85
86 let mut migrations_run = 0;
87 let mut migrations_skipped = 0;
88
89 for migration in &migrations {
90 if applied_versions.contains(&migration.version) {
91 let current_checksum = migration.checksum();
92 if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
93 if stored_checksum != current_checksum {
94 warn!(
95 extension = %ext_id,
96 version = migration.version,
97 name = %migration.name,
98 stored_checksum = %stored_checksum,
99 current_checksum = %current_checksum,
100 "Migration checksum mismatch - SQL has changed since it was applied"
101 );
102 }
103 }
104 migrations_skipped += 1;
105 debug!(
106 extension = %ext_id,
107 version = migration.version,
108 "Migration already applied, skipping"
109 );
110 continue;
111 }
112
113 self.execute_migration(ext_id, migration).await?;
114 migrations_run += 1;
115 }
116
117 if migrations_run > 0 {
118 info!(
119 extension = %ext_id,
120 migrations_run,
121 migrations_skipped,
122 "Migrations completed"
123 );
124 }
125
126 Ok(MigrationResult {
127 migrations_run,
128 migrations_skipped,
129 })
130 }
131
132 async fn execute_migration(
133 &self,
134 ext_id: &str,
135 migration: &Migration,
136 ) -> Result<(), LoaderError> {
137 info!(
138 extension = %ext_id,
139 version = migration.version,
140 name = %migration.name,
141 "Running migration"
142 );
143
144 SqlExecutor::execute_statements_parsed(self.db, migration.sql)
145 .await
146 .map_err(|e| LoaderError::MigrationFailed {
147 extension: ext_id.to_string(),
148 message: format!(
149 "Failed to execute migration {} ({}): {e}",
150 migration.version, migration.name
151 ),
152 })?;
153
154 self.record_migration(ext_id, migration).await?;
155
156 Ok(())
157 }
158
159 async fn record_migration(
160 &self,
161 ext_id: &str,
162 migration: &Migration,
163 ) -> Result<(), LoaderError> {
164 let id = format!("{}_{:03}", ext_id, migration.version);
165 let checksum = migration.checksum();
166 let name = migration.name.replace('\'', "''");
167
168 let sql = format!(
169 "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
170 ('{}', '{}', {}, '{}', '{}')",
171 id, ext_id, migration.version, name, checksum
172 );
173
174 self.db
175 .execute_raw(&sql)
176 .await
177 .map_err(|e| LoaderError::MigrationFailed {
178 extension: ext_id.to_string(),
179 message: format!("Failed to record migration: {e}"),
180 })?;
181
182 Ok(())
183 }
184
185 pub async fn get_migration_status(
186 &self,
187 extension: &dyn Extension,
188 ) -> Result<MigrationStatus, LoaderError> {
189 let ext_id = extension.metadata().id;
190 let defined_migrations = extension.migrations();
191 let applied = self.get_applied_migrations(ext_id).await?;
192
193 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
194
195 let pending: Vec<_> = defined_migrations
196 .iter()
197 .filter(|m| !applied_versions.contains(&m.version))
198 .cloned()
199 .collect();
200
201 Ok(MigrationStatus {
202 extension_id: ext_id.to_string(),
203 total_defined: defined_migrations.len(),
204 total_applied: applied.len(),
205 pending_count: pending.len(),
206 pending,
207 applied,
208 })
209 }
210}
211
212#[derive(Debug)]
213pub struct MigrationStatus {
214 pub extension_id: String,
215 pub total_defined: usize,
216 pub total_applied: usize,
217 pub pending_count: usize,
218 pub pending: Vec<Migration>,
219 pub applied: Vec<AppliedMigration>,
220}