1mod migration;
25
26use crate::migration::{AppliedMigration, Migration};
27use anyhow::{Context, Result};
28use scylla::client::session::Session;
29use std::borrow::Cow;
30use std::collections::HashMap;
31use time::OffsetDateTime;
32use tokio::fs;
33
34#[derive(Debug)]
36pub struct Migrator<'a> {
37 session: &'a Session,
38 migrations_src: &'a str,
39}
40
41impl<'a> Migrator<'a> {
42 pub fn new(session: &'a Session, migrations_src: &'a str) -> Self {
44 Self {
45 session,
46 migrations_src,
47 }
48 }
49
50 async fn create_public_keyspace(&self) -> Result<()> {
51 self.session
52 .query_unpaged(
53 r#"
54 CREATE KEYSPACE IF NOT EXISTS public
55 WITH REPLICATION = {'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}
56 "#,
57 &[],
58 )
59 .await?;
60 self.session.await_schema_agreement().await?;
61 Ok(())
62 }
63
64 async fn create_migration_table(&self) -> Result<()> {
65 self.session
66 .query_unpaged(
67 r#"CREATE TABLE IF NOT EXISTS public.migrations (
68 version bigint,
69 checksum blob,
70 description text,
71 applied_at timestamp,
72 PRIMARY KEY (version, checksum)
73 )"#,
74 &[],
75 )
76 .await?;
77 self.session.await_schema_agreement().await?;
78 Ok(())
79 }
80
81 async fn record_migration(&self, migration: &Migration) -> Result<()> {
82 self.session
83 .query_unpaged(
84 r#"
85 INSERT INTO public.migrations
86 (version, description, checksum, applied_at)
87 VALUES (?, ?, ?, ?)
88 "#,
89 (
90 migration.version,
91 migration.description.as_ref(),
92 migration.checksum.as_ref(),
93 OffsetDateTime::now_utc(),
94 ),
95 )
96 .await?;
97 Ok(())
98 }
99
100 async fn get_applied_migrations(&self) -> Result<HashMap<i64, AppliedMigration>> {
101 let query_rows = self
102 .session
103 .query_unpaged("SELECT version, checksum FROM public.migrations", ())
104 .await?
105 .into_rows_result()
106 .context("Failed to get rows from migrations table")?;
107
108 let mut map = HashMap::new();
109
110 for row in query_rows.rows()? {
111 let (v, c): (i64, Vec<u8>) = row?;
112 map.insert(
113 v,
114 AppliedMigration {
115 checksum: Cow::Owned(c),
116 },
117 );
118 }
119
120 Ok(map)
121 }
122
123 async fn load_migrations(&self) -> Result<Vec<Migration>> {
124 let mut entries = fs::read_dir(&self.migrations_src)
125 .await
126 .context("Could not find migrations directory")?;
127
128 let mut migrations = Vec::new();
129
130 while let Some(entry) = entries.next_entry().await? {
131 if let Ok(meta) = entry.metadata().await {
132 if !meta.is_file() {
133 continue;
134 }
135
136 let path = entry.path();
137 if path.extension().and_then(|ext| ext.to_str()) != Some("cql") {
138 continue;
139 }
140
141 let filename = entry.file_name().to_string_lossy().into_owned();
142
143 let version = filename
144 .split('_')
145 .next()
146 .and_then(|v| v.parse::<i64>().ok())
147 .ok_or_else(|| {
148 anyhow::anyhow!("Invalid migration filename format: {}", filename)
149 })?;
150
151 let cql = fs::read_to_string(path).await?;
152
153 migrations.push(Migration::new(
154 version,
155 Cow::Owned(entry.file_name().to_string_lossy().to_string()),
156 Cow::Owned(cql),
157 ));
158 }
159 }
160
161 migrations.sort_by(|a, b| a.version.cmp(&b.version));
163 Ok(migrations)
164 }
165
166 pub async fn run(&self) -> Result<()> {
173 self.create_public_keyspace().await?;
174 self.create_migration_table().await?;
175
176 let migrations = self.load_migrations().await?;
177 let applied_migrations = self.get_applied_migrations().await?;
178 for migration in migrations {
179 if let Some(applied) = applied_migrations.get(&migration.version) {
180 if applied.checksum.as_ref() == migration.checksum.as_ref() {
181 println!("Migration {} already applied", migration.description);
182 continue;
183 } else {
184 println!(
186 "Migration {} has changes, applying updates",
187 migration.description
188 );
189 }
190 }
191
192 migration.up(self.session).await?;
194 self.record_migration(&migration).await?;
195 println!(
196 "Applied {}/migrate {}",
197 migration.version, migration.description
198 );
199 }
200
201 Ok(())
202 }
203}