systemprompt_database/lifecycle/
migrations.rs1use crate::services::{DatabaseProvider, SqlExecutor};
5use std::collections::HashSet;
6use systemprompt_extension::{Extension, LoaderError, Migration};
7use tracing::{debug, info, warn};
8
/// A migration row read back from the `extension_migrations` table.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Identifier of the extension the migration belongs to.
    pub extension_id: String,
    /// Version number the migration was recorded under.
    pub version: u32,
    /// Migration name as it was recorded.
    pub name: String,
    /// Checksum stored for the migration when it was recorded.
    pub checksum: String,
}
16
/// Counters reported after running an extension's pending migrations.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// Number of migrations executed during this run.
    pub migrations_run: usize,
    /// Number of migrations that were already applied and therefore skipped.
    pub migrations_skipped: usize,
}
22
/// Behavioural switches for the migration service.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationConfig {
    /// When `true`, an applied migration whose checksum no longer matches its
    /// current definition is logged as a warning instead of aborting the run.
    /// Defaults to `false`.
    pub allow_checksum_drift: bool,
}
27
28pub struct MigrationService<'a> {
29 db: &'a dyn DatabaseProvider,
30 config: MigrationConfig,
31}
32
33impl std::fmt::Debug for MigrationService<'_> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("MigrationService")
36 .field("config", &self.config)
37 .finish_non_exhaustive()
38 }
39}
40
41impl<'a> MigrationService<'a> {
42 pub fn new(db: &'a dyn DatabaseProvider) -> Self {
43 Self {
44 db,
45 config: MigrationConfig::default(),
46 }
47 }
48
49 #[must_use]
50 pub const fn with_config(mut self, config: MigrationConfig) -> Self {
51 self.config = config;
52 self
53 }
54
55 async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
56 let sql = include_str!("../../schema/extension_migrations.sql");
57 SqlExecutor::execute_statements_parsed(self.db, sql)
58 .await
59 .map_err(|e| LoaderError::MigrationFailed {
60 extension: "database".to_string(),
61 message: format!("Failed to ensure migrations table exists: {e}"),
62 })
63 }
64
65 pub async fn get_applied_migrations(
66 &self,
67 extension_id: &str,
68 ) -> Result<Vec<AppliedMigration>, LoaderError> {
69 let result = self
70 .db
71 .query_raw_with(
72 &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
73 extension_id = $1 ORDER BY version",
74 vec![serde_json::Value::String(extension_id.to_string())],
75 )
76 .await
77 .map_err(|e| LoaderError::MigrationFailed {
78 extension: extension_id.to_string(),
79 message: format!("Failed to query applied migrations: {e}"),
80 })?;
81
82 let migrations = result
83 .rows
84 .iter()
85 .filter_map(|row| {
86 Some(AppliedMigration {
87 extension_id: row.get("extension_id")?.as_str()?.to_string(),
88 version: row.get("version")?.as_i64()? as u32,
89 name: row.get("name")?.as_str()?.to_string(),
90 checksum: row.get("checksum")?.as_str()?.to_string(),
91 })
92 })
93 .collect();
94
95 Ok(migrations)
96 }
97
98 pub async fn run_pending_migrations(
99 &self,
100 extension: &dyn Extension,
101 ) -> Result<MigrationResult, LoaderError> {
102 let ext_id = extension.metadata().id;
103 let migrations = extension.migrations();
104
105 if migrations.is_empty() {
106 return Ok(MigrationResult::default());
107 }
108
109 self.ensure_migrations_table_exists().await?;
110
111 let applied = self.get_applied_migrations(ext_id).await?;
112 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
113 let applied_checksums: std::collections::HashMap<u32, &str> = applied
114 .iter()
115 .map(|m| (m.version, m.checksum.as_str()))
116 .collect();
117
118 let mut migrations_run = 0;
119 let mut migrations_skipped = 0;
120
121 for migration in &migrations {
122 if applied_versions.contains(&migration.version) {
123 let current_checksum = migration.checksum();
124 if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
125 if stored_checksum != current_checksum {
126 if self.config.allow_checksum_drift {
127 warn!(
128 extension = %ext_id,
129 version = migration.version,
130 name = %migration.name,
131 stored_checksum = %stored_checksum,
132 current_checksum = %current_checksum,
133 "Migration checksum mismatch tolerated by --allow-checksum-drift"
134 );
135 } else {
136 return Err(LoaderError::MigrationFailed {
137 extension: ext_id.to_string(),
138 message: format!(
139 "Migration {ver} ('{name}') has been edited since it was \
140 applied (stored checksum {stored_checksum}, current \
141 {current_checksum}). Refusing to proceed. Re-run with \
142 --allow-checksum-drift to override.",
143 ver = migration.version,
144 name = migration.name,
145 ),
146 });
147 }
148 }
149 }
150 migrations_skipped += 1;
151 debug!(
152 extension = %ext_id,
153 version = migration.version,
154 "Migration already applied, skipping"
155 );
156 continue;
157 }
158
159 self.execute_migration(ext_id, migration).await?;
160 migrations_run += 1;
161 }
162
163 if migrations_run > 0 {
164 info!(
165 extension = %ext_id,
166 migrations_run,
167 migrations_skipped,
168 "Migrations completed"
169 );
170 }
171
172 Ok(MigrationResult {
173 migrations_run,
174 migrations_skipped,
175 })
176 }
177
178 async fn execute_migration(
179 &self,
180 ext_id: &str,
181 migration: &Migration,
182 ) -> Result<(), LoaderError> {
183 info!(
184 extension = %ext_id,
185 version = migration.version,
186 name = %migration.name,
187 "Running migration"
188 );
189
190 SqlExecutor::execute_statements_parsed(self.db, migration.sql)
191 .await
192 .map_err(|e| LoaderError::MigrationFailed {
193 extension: ext_id.to_string(),
194 message: format!(
195 "Failed to execute migration {} ({}): {e}",
196 migration.version, migration.name
197 ),
198 })?;
199
200 self.record_migration(ext_id, migration).await?;
201
202 Ok(())
203 }
204
205 async fn record_migration(
206 &self,
207 ext_id: &str,
208 migration: &Migration,
209 ) -> Result<(), LoaderError> {
210 let id = format!("{}_{:03}", ext_id, migration.version);
211 let checksum = migration.checksum();
212 let name = migration.name.replace('\'', "''");
213
214 let sql = format!(
215 "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
216 ('{}', '{}', {}, '{}', '{}')",
217 id, ext_id, migration.version, name, checksum
218 );
219
220 self.db
221 .execute_raw(&sql)
222 .await
223 .map_err(|e| LoaderError::MigrationFailed {
224 extension: ext_id.to_string(),
225 message: format!("Failed to record migration: {e}"),
226 })?;
227
228 Ok(())
229 }
230
231 pub async fn get_migration_status(
232 &self,
233 extension: &dyn Extension,
234 ) -> Result<MigrationStatus, LoaderError> {
235 self.ensure_migrations_table_exists().await?;
236
237 let ext_id = extension.metadata().id;
238 let defined_migrations = extension.migrations();
239 let applied = self.get_applied_migrations(ext_id).await?;
240
241 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
242
243 let pending: Vec<_> = defined_migrations
244 .iter()
245 .filter(|m| !applied_versions.contains(&m.version))
246 .cloned()
247 .collect();
248
249 Ok(MigrationStatus {
250 extension_id: ext_id.to_string(),
251 total_defined: defined_migrations.len(),
252 total_applied: applied.len(),
253 pending_count: pending.len(),
254 pending,
255 applied,
256 })
257 }
258}
259
260#[derive(Debug)]
261pub struct MigrationStatus {
262 pub extension_id: String,
263 pub total_defined: usize,
264 pub total_applied: usize,
265 pub pending_count: usize,
266 pub pending: Vec<Migration>,
267 pub applied: Vec<AppliedMigration>,
268}