Skip to main content

systemprompt_database/lifecycle/migrations/
squash.rs

1//! Collapsing a contiguous range of applied migrations into a single
2//! version-0 baseline row.
3
4use super::MigrationService;
5use std::collections::HashSet;
6use systemprompt_extension::{Extension, LoaderError, Migration};
7use tracing::info;
8
/// Outcome of planning (and optionally applying) a migration squash for one
/// extension.
///
/// Produced by `MigrationService::squash_through`. When `applied` is `false`
/// this is a dry-run plan; when `true` the baseline row has been recorded and
/// the source rows retired.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SquashPlan {
    /// Id of the extension whose migrations are squashed.
    pub extension_id: String,
    /// Highest migration version folded into the baseline (inclusive).
    pub through: u32,
    /// Name recorded on the version-0 baseline row (`baseline_v{through}`).
    pub baseline_name: String,
    /// Concatenated SQL of every squashed migration, each section prefixed by a
    /// `-- migration NNN: name` header line.
    pub baseline_sql: String,
    /// Checksum of `baseline_sql`, stored on the baseline row.
    pub baseline_checksum: String,
    /// Versions whose SQL was folded into `baseline_sql`, in ascending order.
    pub source_versions: Vec<u32>,
    /// Versions confirmed as already applied before squashing (`1..=through`).
    pub already_applied_versions: Vec<u32>,
    /// Whether the squash was written to the database (`false` = dry run).
    pub applied: bool,
}
20
/// Computes a stable, hex-encoded checksum of the baseline SQL.
///
/// The value is persisted in `extension_migrations.checksum`, so it must be
/// reproducible across builds. `std`'s `DefaultHasher` explicitly leaves its
/// algorithm unspecified and subject to change between Rust releases, so this
/// uses 64-bit FNV-1a over the UTF-8 bytes instead. The result is always
/// 16 lowercase hex digits (zero-padded).
fn baseline_checksum(sql: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = sql.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{hash:016x}")
}
27
28fn collect_squash_range<'m>(
29    ext_id: &str,
30    migrations: &'m [Migration],
31    through: u32,
32) -> Result<Vec<&'m Migration>, LoaderError> {
33    if through == 0 {
34        return Err(LoaderError::MigrationFailed {
35            extension: ext_id.to_string(),
36            message: "--through must be >= 1; version 0 is reserved for the squash baseline"
37                .to_string(),
38        });
39    }
40
41    let to_squash: Vec<&Migration> = migrations
42        .iter()
43        .filter(|m| m.version >= 1 && m.version <= through)
44        .collect();
45
46    if to_squash.is_empty() {
47        return Err(LoaderError::MigrationFailed {
48            extension: ext_id.to_string(),
49            message: format!(
50                "No migrations in range 1..={through} are defined for extension '{ext_id}'"
51            ),
52        });
53    }
54
55    let mut covered: Vec<u32> = to_squash.iter().map(|m| m.version).collect();
56    covered.sort_unstable();
57    if covered != (1..=through).collect::<Vec<u32>>() {
58        return Err(LoaderError::MigrationFailed {
59            extension: ext_id.to_string(),
60            message: format!(
61                "Migrations 1..={through} are not contiguous for extension '{ext_id}': have \
62                 {covered:?}"
63            ),
64        });
65    }
66
67    Ok(to_squash)
68}
69
70fn build_baseline_sql(to_squash: &[&Migration]) -> String {
71    let mut baseline_sql = String::new();
72    for m in to_squash {
73        baseline_sql.push_str(&format!(
74            "-- migration {ver:03}: {name}\n",
75            ver = m.version,
76            name = m.name
77        ));
78        baseline_sql.push_str(m.sql);
79        if !m.sql.ends_with('\n') {
80            baseline_sql.push('\n');
81        }
82        baseline_sql.push('\n');
83    }
84    baseline_sql
85}
86
87impl MigrationService<'_> {
88    pub async fn squash_through(
89        &self,
90        extension: &dyn Extension,
91        through: u32,
92        apply: bool,
93    ) -> Result<SquashPlan, LoaderError> {
94        let ext_id = extension.metadata().id;
95
96        let mut migrations = extension.migrations();
97        migrations.sort_by_key(|m| m.version);
98        let to_squash = collect_squash_range(ext_id, &migrations, through)?;
99
100        let baseline_sql = build_baseline_sql(&to_squash);
101        let checksum = baseline_checksum(&baseline_sql);
102        let baseline_name = format!("baseline_v{through}");
103        let covered: Vec<u32> = to_squash.iter().map(|m| m.version).collect();
104
105        self.verify_range_applied(ext_id, through).await?;
106
107        let plan = SquashPlan {
108            extension_id: ext_id.to_string(),
109            through,
110            baseline_name: baseline_name.clone(),
111            baseline_sql,
112            baseline_checksum: checksum.clone(),
113            source_versions: covered,
114            already_applied_versions: (1..=through).collect(),
115            applied: false,
116        };
117
118        if !apply {
119            return Ok(plan);
120        }
121
122        self.apply_squash_rows(ext_id, through, &baseline_name, &checksum)
123            .await?;
124
125        Ok(SquashPlan {
126            applied: true,
127            ..plan
128        })
129    }
130
131    async fn verify_range_applied(&self, ext_id: &str, through: u32) -> Result<(), LoaderError> {
132        self.ensure_migrations_table_exists().await?;
133        let applied = self.get_applied_migrations(ext_id).await?;
134        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
135        let not_applied: Vec<u32> = (1..=through)
136            .filter(|v| !applied_versions.contains(v))
137            .collect();
138        if not_applied.is_empty() {
139            return Ok(());
140        }
141        Err(LoaderError::MigrationFailed {
142            extension: ext_id.to_string(),
143            message: format!(
144                "Refusing to squash through {through}: extension '{ext_id}' has not applied \
145                 migrations {not_applied:?}. Squashing would retire them behind the baseline \
146                 without ever running them. Apply migrations 1..={through} first."
147            ),
148        })
149    }
150
151    async fn apply_squash_rows(
152        &self,
153        ext_id: &str,
154        through: u32,
155        baseline_name: &str,
156        checksum: &str,
157    ) -> Result<(), LoaderError> {
158        let baseline_id = format!("{ext_id}_000");
159
160        self.db
161            .execute(
162                &"INSERT INTO extension_migrations (id, extension_id, version, name, checksum) \
163                  VALUES ($1, $2, 0, $3, $4) ON CONFLICT (extension_id, version) DO UPDATE SET \
164                  name = EXCLUDED.name, checksum = EXCLUDED.checksum",
165                &[&baseline_id, &ext_id, &baseline_name, &checksum],
166            )
167            .await
168            .map_err(|e| LoaderError::MigrationFailed {
169                extension: ext_id.to_string(),
170                message: format!("Failed to record baseline migration row: {e}"),
171            })?;
172
173        self.db
174            .execute(
175                &"DELETE FROM extension_migrations WHERE extension_id = $1 AND version BETWEEN 1 \
176                  AND $2",
177                &[&ext_id, &through],
178            )
179            .await
180            .map_err(|e| LoaderError::MigrationFailed {
181                extension: ext_id.to_string(),
182                message: format!("Failed to retire squashed migration rows: {e}"),
183            })?;
184
185        info!(
186            extension = %ext_id,
187            through,
188            baseline_name = %baseline_name,
189            "Squash applied: baseline row inserted, source rows retired"
190        );
191
192        Ok(())
193    }
194}