systemprompt_database/lifecycle/migrations/
mod.rs1mod down;
7mod exec;
8mod squash;
9mod status;
10
11pub use squash::SquashPlan;
12pub use status::{
13 AppliedMigration, ChecksumDrift, ExtensionMigrationStatus, MigrationResult, MigrationStatus,
14 PendingMigration,
15};
16
17use crate::services::{DatabaseProvider, SqlExecutor};
18use exec::{check_cross_extension_alters, execute_statements_transactional};
19use std::collections::HashSet;
20use systemprompt_extension::{Extension, LoaderError, Migration};
21use tracing::{debug, info, warn};
22
/// Options that tune how [`MigrationService`] treats already-applied migrations.
///
/// The [`Default`] value leaves every option switched off.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationConfig {
    /// When `true`, an applied migration whose current checksum differs from the
    /// stored one is reported with a warning instead of aborting with an error.
    pub allow_checksum_drift: bool,
}
27
28pub struct MigrationService<'a> {
29 db: &'a dyn DatabaseProvider,
30 config: MigrationConfig,
31}
32
33impl std::fmt::Debug for MigrationService<'_> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("MigrationService")
36 .field("config", &self.config)
37 .finish_non_exhaustive()
38 }
39}
40
41impl<'a> MigrationService<'a> {
42 pub fn new(db: &'a dyn DatabaseProvider) -> Self {
43 Self {
44 db,
45 config: MigrationConfig::default(),
46 }
47 }
48
49 #[must_use]
50 pub const fn with_config(mut self, config: MigrationConfig) -> Self {
51 self.config = config;
52 self
53 }
54
55 async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
56 let sql = include_str!("../../../schema/extension_migrations.sql");
57 SqlExecutor::execute_statements_parsed(self.db, sql)
58 .await
59 .map_err(|e| LoaderError::MigrationFailed {
60 extension: "database".to_string(),
61 message: format!("Failed to ensure migrations table exists: {e}"),
62 })
63 }
64
65 pub async fn get_applied_migrations(
66 &self,
67 extension_id: &str,
68 ) -> Result<Vec<AppliedMigration>, LoaderError> {
69 let result = self
70 .db
71 .query_raw_with(
72 &"SELECT extension_id, version, name, checksum, applied_at FROM \
73 extension_migrations WHERE extension_id = $1 ORDER BY version",
74 &[&extension_id],
75 )
76 .await
77 .map_err(|e| LoaderError::MigrationFailed {
78 extension: extension_id.to_string(),
79 message: format!("Failed to query applied migrations: {e}"),
80 })?;
81
82 let migrations = result
83 .rows
84 .iter()
85 .filter_map(|row| {
86 Some(AppliedMigration {
87 extension_id: row.get("extension_id")?.as_str()?.to_string(),
88 version: row.get("version")?.as_i64()? as u32,
89 name: row.get("name")?.as_str()?.to_string(),
90 checksum: row.get("checksum")?.as_str()?.to_string(),
91 applied_at: row
92 .get("applied_at")
93 .and_then(|v| v.as_str().map(String::from)),
94 })
95 })
96 .collect();
97
98 Ok(migrations)
99 }
100
101 pub async fn run_pending_migrations(
102 &self,
103 extension: &dyn Extension,
104 ) -> Result<MigrationResult, LoaderError> {
105 let ext_id = extension.metadata().id;
106 let migrations = extension.migrations();
107
108 if migrations.is_empty() {
109 return Ok(MigrationResult::default());
110 }
111
112 self.ensure_migrations_table_exists().await?;
113
114 let applied = self.get_applied_migrations(ext_id).await?;
115 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
116 let applied_checksums: std::collections::HashMap<u32, &str> = applied
117 .iter()
118 .map(|m| (m.version, m.checksum.as_str()))
119 .collect();
120
121 let mut migrations_run = 0;
122 let mut migrations_skipped = 0;
123
124 for migration in &migrations {
125 if applied_versions.contains(&migration.version) {
126 self.verify_checksum(
127 ext_id,
128 migration,
129 applied_checksums.get(&migration.version).copied(),
130 )?;
131 migrations_skipped += 1;
132 debug!(
133 extension = %ext_id,
134 version = migration.version,
135 "Migration already applied, skipping"
136 );
137 continue;
138 }
139
140 self.execute_migration(extension, migration).await?;
141 migrations_run += 1;
142 }
143
144 if migrations_run > 0 {
145 info!(
146 extension = %ext_id,
147 migrations_run,
148 migrations_skipped,
149 "Migrations completed"
150 );
151 }
152
153 Ok(MigrationResult {
154 migrations_run,
155 migrations_skipped,
156 })
157 }
158
159 fn verify_checksum(
160 &self,
161 ext_id: &str,
162 migration: &Migration,
163 stored: Option<&str>,
164 ) -> Result<(), LoaderError> {
165 let Some(stored_checksum) = stored else {
166 return Ok(());
167 };
168 let current_checksum = migration.checksum();
169 if stored_checksum == current_checksum {
170 return Ok(());
171 }
172 if self.config.allow_checksum_drift {
173 warn!(
174 extension = %ext_id,
175 version = migration.version,
176 name = %migration.name,
177 stored_checksum = %stored_checksum,
178 current_checksum = %current_checksum,
179 "Migration checksum mismatch tolerated by --allow-checksum-drift"
180 );
181 return Ok(());
182 }
183 Err(LoaderError::MigrationFailed {
184 extension: ext_id.to_string(),
185 message: format!(
186 "Migration {ver} ('{name}') has been edited since it was applied (stored checksum \
187 {stored_checksum}, current {current_checksum}). Refusing to proceed. Re-run with \
188 --allow-checksum-drift to override.",
189 ver = migration.version,
190 name = migration.name,
191 ),
192 })
193 }
194
195 async fn execute_migration(
196 &self,
197 extension: &dyn Extension,
198 migration: &Migration,
199 ) -> Result<(), LoaderError> {
200 let ext_id = extension.metadata().id;
201
202 check_cross_extension_alters(extension, migration)?;
203
204 info!(
205 extension = %ext_id,
206 version = migration.version,
207 name = %migration.name,
208 no_transaction = migration.no_transaction,
209 "Running migration"
210 );
211
212 if migration.no_transaction {
213 SqlExecutor::execute_statements_parsed(self.db, migration.sql)
214 .await
215 .map_err(|e| LoaderError::MigrationFailed {
216 extension: ext_id.to_string(),
217 message: format!(
218 "Failed to execute migration {} ({}): {e}",
219 migration.version, migration.name
220 ),
221 })?;
222 } else {
223 let statements = SqlExecutor::parse_sql_statements(migration.sql).map_err(|e| {
224 LoaderError::MigrationFailed {
225 extension: ext_id.to_string(),
226 message: format!(
227 "Failed to parse migration {} ({}): {e}",
228 migration.version, migration.name
229 ),
230 }
231 })?;
232 execute_statements_transactional(self.db, &statements, ext_id, migration).await?;
233 }
234
235 self.record_migration(ext_id, migration).await?;
236
237 Ok(())
238 }
239
240 async fn record_migration(
241 &self,
242 ext_id: &str,
243 migration: &Migration,
244 ) -> Result<(), LoaderError> {
245 let id = format!("{}_{:03}", ext_id, migration.version);
246 let checksum = migration.checksum();
247
248 self.db
249 .execute(
250 &"INSERT INTO extension_migrations (id, extension_id, version, name, checksum) \
251 VALUES ($1, $2, $3, $4, $5)",
252 &[&id, &ext_id, &migration.version, &migration.name, &checksum],
253 )
254 .await
255 .map_err(|e| LoaderError::MigrationFailed {
256 extension: ext_id.to_string(),
257 message: format!("Failed to record migration: {e}"),
258 })?;
259
260 Ok(())
261 }
262}