1use std::fs;
2use std::path::{Path, PathBuf};
3
4use rustauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
5use rustauth_core::error::RustAuthError;
6#[cfg(feature = "deadpool-postgres")]
7use rustauth_deadpool_postgres::DeadpoolPostgresAdapter;
8#[cfg(feature = "sqlx")]
9use rustauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
10#[cfg(feature = "tokio-postgres")]
11use rustauth_tokio_postgres::TokioPostgresAdapter;
12use serde::Serialize;
13use sha2::{Digest, Sha256};
14use time::format_description::well_known::Rfc3339;
15use time::OffsetDateTime;
16
17use crate::config::CliConfig;
18use crate::plugins::plugin_migrations_for_config;
19use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
20
/// Returns `true` when `adapter` names a migration adapter whose Cargo feature
/// is compiled into this CLI build.
///
/// Known adapter names whose feature is disabled, and all unknown names,
/// yield `false`.
pub fn is_cli_migration_adapter(adapter: &str) -> bool {
    match adapter {
        "sqlx" => cfg!(feature = "sqlx"),
        "tokio-postgres" => cfg!(feature = "tokio-postgres"),
        "deadpool-postgres" => cfg!(feature = "deadpool-postgres"),
        _ => false,
    }
}
29
/// Returns `true` when `adapter` is one of the adapter names the CLI knows how
/// to migrate with, regardless of whether its Cargo feature is enabled.
pub fn is_known_cli_migration_adapter(adapter: &str) -> bool {
    ["sqlx", "tokio-postgres", "deadpool-postgres"].contains(&adapter)
}
33
34fn is_adapter_feature_disabled(adapter: &str) -> bool {
35 is_known_cli_migration_adapter(adapter) && !is_cli_migration_adapter(adapter)
36}
37
/// Lists the migration adapter names enabled in this CLI build, in the fixed
/// order `sqlx`, `tokio-postgres`, `deadpool-postgres`.
pub fn cli_migration_adapter_names() -> Vec<&'static str> {
    [
        ("sqlx", cfg!(feature = "sqlx")),
        ("tokio-postgres", cfg!(feature = "tokio-postgres")),
        ("deadpool-postgres", cfg!(feature = "deadpool-postgres")),
    ]
    .into_iter()
    .filter(|(_, enabled)| *enabled)
    .map(|(name, _)| name)
    .collect()
}
51
/// Returns `true` when `provider` is one of the accepted PostgreSQL spellings
/// (`postgres`, `postgresql`, `pg`). The comparison is case-sensitive.
fn is_postgres_provider(provider: &str) -> bool {
    ["postgres", "postgresql", "pg"].contains(&provider)
}
55
56#[derive(Debug, thiserror::Error)]
57pub enum DbCliError {
58 #[error("database provider is not configured")]
59 MissingProvider,
60 #[error("database URL environment variable {0} is not set; add it to .env/.env.local next to the project or config file, or export it before running this command")]
61 MissingDatabaseUrl(String),
62 #[error(
63 "unsupported database adapter `{adapter}`; {support}",
64 adapter = .0,
65 support = unsupported_adapter_support_suffix()
66 )]
67 UnsupportedAdapter(String),
68 #[error(
69 "database adapter `{0}` is not enabled in this CLI build; rebuild with the matching \
70 Cargo feature ({1})"
71 )]
72 AdapterFeatureDisabled(String, String),
73 #[error("unsupported database provider `{0}`")]
74 UnsupportedProvider(String),
75 #[error("migration has non-executable warnings; fix schema mismatches before applying")]
76 UnsafeMigration,
77 #[error("A migration for this plan already exists: {0}")]
78 DuplicateMigration(String),
79 #[error("database error: {0}")]
80 RustAuth(#[from] RustAuthError),
81 #[error("failed to write {path}: {source}")]
82 Write {
83 path: PathBuf,
84 source: std::io::Error,
85 },
86 #[error("failed to read {path}: {source}")]
87 Read {
88 path: PathBuf,
89 source: std::io::Error,
90 },
91 #[error("failed to create {path}: {source}")]
92 CreateDir {
93 path: PathBuf,
94 source: std::io::Error,
95 },
96 #[error("failed to format timestamp: {0}")]
97 TimeFormat(#[from] time::error::Format),
98}
99
100#[derive(Debug, Clone, Serialize)]
101pub struct PlanSummary {
102 pub provider: String,
103 pub tables_to_create: usize,
104 pub columns_to_add: usize,
105 pub indexes_to_create: usize,
106 pub warnings: Vec<SchemaMigrationWarning>,
107 pub statements: usize,
108 pub plan_hash: String,
109}
110
111#[derive(Debug, Clone)]
112pub struct PlannedMigration {
113 pub schema: DbSchema,
114 pub plan: SchemaMigrationPlan,
115 pub provider: String,
116}
117
118impl PlannedMigration {
119 pub fn summary(&self) -> PlanSummary {
120 PlanSummary {
121 provider: self.provider.clone(),
122 tables_to_create: self.plan.to_be_created.len(),
123 columns_to_add: self.plan.to_be_added.len(),
124 indexes_to_create: self.plan.indexes_to_be_created.len(),
125 warnings: self.plan.warnings.clone(),
126 statements: self.plan.statements.len(),
127 plan_hash: plan_hash(&self.plan),
128 }
129 }
130}
131
132pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
133 plan_with_base(config, from_empty, None).await
134}
135
136pub async fn plan_with_base(
137 config: &CliConfig,
138 from_empty: bool,
139 cwd: Option<&Path>,
140) -> Result<PlannedMigration, DbCliError> {
141 validate_cli_migration_adapter(config)?;
142 let schema = target_schema(config)?;
143 let provider = config
144 .database
145 .provider
146 .clone()
147 .ok_or(DbCliError::MissingProvider)?;
148
149 let plan = if from_empty {
150 let dialect = dialect_from_provider(&provider)
151 .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
152 full_schema_plan(dialect, &schema)?
153 } else {
154 let database_url = database_url_with_base(config, cwd)?;
155 match config.database.adapter.as_str() {
156 #[cfg(feature = "sqlx")]
157 "sqlx" => plan_with_sqlx(&provider, &database_url, &schema).await?,
158 #[cfg(feature = "tokio-postgres")]
159 "tokio-postgres" => {
160 if !is_postgres_provider(&provider) {
161 return Err(DbCliError::UnsupportedProvider(provider));
162 }
163 TokioPostgresAdapter::connect_with_schema(&database_url, schema.clone())
164 .await?
165 .plan_migrations(&schema)
166 .await?
167 }
168 #[cfg(feature = "deadpool-postgres")]
169 "deadpool-postgres" => {
170 if !is_postgres_provider(&provider) {
171 return Err(DbCliError::UnsupportedProvider(provider));
172 }
173 DeadpoolPostgresAdapter::builder()
174 .database_url(database_url)
175 .schema(schema.clone())
176 .connect()
177 .await?
178 .plan_migrations(&schema)
179 .await?
180 }
181 adapter => return Err(adapter_dispatch_error(adapter)),
182 }
183 };
184
185 Ok(PlannedMigration {
186 schema,
187 plan,
188 provider,
189 })
190}
191
192pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
193 migrate_with_base(config, None).await
194}
195
196pub async fn migrate_with_base(
197 config: &CliConfig,
198 cwd: Option<&Path>,
199) -> Result<PlannedMigration, DbCliError> {
200 let planned = plan_with_base(config, false, cwd).await?;
201 if !planned.plan.warnings.is_empty() {
202 return Err(DbCliError::UnsafeMigration);
203 }
204 let database_url = database_url_with_base(config, cwd)?;
205 let plugin_migrations = plugin_migrations_for_config(&config.plugins.enabled)?;
206 match config.database.adapter.as_str() {
207 #[cfg(feature = "sqlx")]
208 "sqlx" => {
209 run_migrations_with_sqlx(
210 &planned.provider,
211 &database_url,
212 &planned.schema,
213 &plugin_migrations,
214 )
215 .await?;
216 }
217 #[cfg(feature = "tokio-postgres")]
218 "tokio-postgres" => {
219 let adapter =
220 TokioPostgresAdapter::connect_with_schema(&database_url, planned.schema.clone())
221 .await?;
222 adapter.run_migrations(&planned.schema).await?;
223 adapter.run_plugin_migrations(&plugin_migrations).await?;
224 }
225 #[cfg(feature = "deadpool-postgres")]
226 "deadpool-postgres" => {
227 let adapter = DeadpoolPostgresAdapter::builder()
228 .database_url(database_url)
229 .schema(planned.schema.clone())
230 .connect()
231 .await?;
232 adapter.run_migrations(&planned.schema).await?;
233 adapter.run_plugin_migrations(&plugin_migrations).await?;
234 }
235 adapter => return Err(adapter_dispatch_error(adapter)),
236 }
237 Ok(planned)
238}
239
/// Connects through the sqlx adapter matching `provider` and plans the
/// migration for `schema`.
///
/// For SQLite the database file is ensured to exist before connecting.
/// Providers other than the SQLite, PostgreSQL and MySQL spellings handled
/// here yield [`DbCliError::UnsupportedProvider`].
#[cfg(feature = "sqlx")]
async fn plan_with_sqlx(
    provider: &str,
    database_url: &str,
    schema: &DbSchema,
) -> Result<SchemaMigrationPlan, DbCliError> {
    let plan = match provider {
        "sqlite" | "sqlite3" => {
            ensure_sqlite_database(database_url)?;
            let sqlite = SqliteAdapter::connect_with_schema(database_url, schema.clone()).await?;
            sqlite.plan_migrations(schema).await?
        }
        "postgres" | "postgresql" | "pg" => {
            let postgres =
                PostgresAdapter::connect_with_schema(database_url, schema.clone()).await?;
            postgres.plan_migrations(schema).await?
        }
        "mysql" => {
            let mysql = MySqlAdapter::connect_with_schema(database_url, schema.clone()).await?;
            mysql.plan_migrations(schema).await?
        }
        other => return Err(DbCliError::UnsupportedProvider(other.to_owned())),
    };
    Ok(plan)
}
270
/// Connects through the sqlx adapter matching `provider`, then runs the
/// schema migrations followed by the plugin migrations.
///
/// For SQLite the database file is ensured to exist before connecting.
/// Providers other than the SQLite, PostgreSQL and MySQL spellings handled
/// here yield [`DbCliError::UnsupportedProvider`].
#[cfg(feature = "sqlx")]
async fn run_migrations_with_sqlx(
    provider: &str,
    database_url: &str,
    schema: &DbSchema,
    plugin_migrations: &[rustauth_core::plugin::PluginMigration],
) -> Result<(), DbCliError> {
    match provider {
        "sqlite" | "sqlite3" => {
            ensure_sqlite_database(database_url)?;
            let sqlite = SqliteAdapter::connect_with_schema(database_url, schema.clone()).await?;
            sqlite.run_migrations(schema).await?;
            sqlite.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        "postgres" | "postgresql" | "pg" => {
            let postgres =
                PostgresAdapter::connect_with_schema(database_url, schema.clone()).await?;
            postgres.run_migrations(schema).await?;
            postgres.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        "mysql" => {
            let mysql = MySqlAdapter::connect_with_schema(database_url, schema.clone()).await?;
            mysql.run_migrations(schema).await?;
            mysql.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        other => Err(DbCliError::UnsupportedProvider(other.to_owned())),
    }
}
300
301pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
302 let dialect = dialect_from_provider(&planned.provider)
303 .ok_or_else(|| DbCliError::UnsupportedProvider(planned.provider.clone()))?;
304 let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
305 let schema_hash = schema_hash(&planned.schema)?;
306 let plan_hash = plan_hash(&planned.plan);
307 Ok(format!(
308 "-- RustAuth migration\n-- dialect: {}\n-- generated_at: {}\n-- schema_hash: {}\n-- plan_hash: {}\n-- config_base_path: {}\n\n{}",
309 dialect_name(dialect),
310 generated_at,
311 schema_hash,
312 plan_hash,
313 config.project.base_path,
314 planned.plan.compile()
315 ))
316}
317
318pub fn write_migration(
319 config: &CliConfig,
320 planned: &PlannedMigration,
321 output: Option<&Path>,
322 force: bool,
323) -> Result<PathBuf, DbCliError> {
324 write_migration_output(
325 config,
326 planned,
327 output
328 .map(|path| MigrationOutput::Directory(path.to_path_buf()))
329 .unwrap_or(MigrationOutput::Default),
330 force,
331 )
332}
333
/// Destination selector for a generated migration file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationOutput {
    /// Write a generated file name into the configured migrations directory.
    Default,
    /// Write a generated file name into the given directory.
    Directory(PathBuf),
    /// Write to exactly this file path.
    File(PathBuf),
}
339
340pub fn write_migration_output(
341 config: &CliConfig,
342 planned: &PlannedMigration,
343 output: MigrationOutput,
344 force: bool,
345) -> Result<PathBuf, DbCliError> {
346 if planned.plan.is_empty() {
347 return Ok(PathBuf::new());
348 }
349 let (dir, explicit_file) = match output {
350 MigrationOutput::Default => (PathBuf::from(&config.database.migrations_dir), None),
351 MigrationOutput::Directory(dir) => (dir, None),
352 MigrationOutput::File(path) => (
353 path.parent()
354 .map(Path::to_path_buf)
355 .unwrap_or_else(|| PathBuf::from(".")),
356 Some(path),
357 ),
358 };
359 let hash = plan_hash(&planned.plan);
360 if !force {
361 if let Some(existing) = find_existing_plan_hash(&dir, &hash)? {
362 return Err(DbCliError::DuplicateMigration(
363 existing.display().to_string(),
364 ));
365 }
366 }
367 fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
368 path: dir.clone(),
369 source,
370 })?;
371 let path = explicit_file.unwrap_or_else(|| {
372 dir.join(format!(
373 "{}_{}_{}.sql",
374 filename_timestamp(),
375 normalized_provider(&planned.provider),
376 hash
377 ))
378 });
379 if path.exists() && !force {
380 return Err(DbCliError::DuplicateMigration(path.display().to_string()));
381 }
382 let sql = migration_sql(config, planned)?;
383 fs::write(&path, sql).map_err(|source| DbCliError::Write {
384 path: path.clone(),
385 source,
386 })?;
387 Ok(path)
388}
389
390pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
391 let payload = serde_json::to_vec(schema)
392 .map_err(|error| RustAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
393 Ok(short_hash(&payload))
394}
395
396pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
397 short_hash(plan.compile().as_bytes())
398}
399
400pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
401 database_url_with_base(config, None)
402}
403
404pub fn database_url_with_base(
405 config: &CliConfig,
406 cwd: Option<&Path>,
407) -> Result<String, DbCliError> {
408 std::env::var(&config.database.url_env)
409 .map(|url| normalize_database_url(config.database.provider.as_deref(), &url, cwd))
410 .map_err(|_| DbCliError::MissingDatabaseUrl(config.database.url_env.clone()))
411}
412
413pub fn supports_sql_migrations(config: &CliConfig) -> bool {
414 if !is_cli_migration_adapter(&config.database.adapter) {
415 return false;
416 }
417 match config.database.adapter.as_str() {
418 "sqlx" if cfg!(feature = "sqlx") => config
419 .database
420 .provider
421 .as_deref()
422 .is_some_and(|provider| dialect_from_provider(provider).is_some()),
423 "tokio-postgres" if cfg!(feature = "tokio-postgres") => config
424 .database
425 .provider
426 .as_deref()
427 .is_some_and(is_postgres_provider),
428 "deadpool-postgres" if cfg!(feature = "deadpool-postgres") => config
429 .database
430 .provider
431 .as_deref()
432 .is_some_and(is_postgres_provider),
433 _ => false,
434 }
435}
436
/// Returns `true` for the adapter names `prisma`, `drizzle`, `memory`,
/// `mongodb` and `kysely`, which the CLI does not migrate itself.
pub fn unsupported_adapter_exits_successfully(adapter: &str) -> bool {
    const GUIDED_ADAPTERS: [&str; 5] = ["prisma", "drizzle", "memory", "mongodb", "kysely"];
    GUIDED_ADAPTERS.contains(&adapter)
}
446
447pub fn unsupported_adapter_guidance(adapter: &str, command: &str) -> String {
448 match adapter {
449 "prisma" => format!(
450 "The {command} command applies RustAuth SQL migrations through the sqlx adapter. \
451 With Prisma configured, run `rustauth db generate` to write `.sql` files, then apply \
452 them with `prisma migrate` or `prisma db push`."
453 ),
454 "drizzle" => format!(
455 "The {command} command applies RustAuth SQL migrations through the sqlx adapter. \
456 With Drizzle configured, run `rustauth db generate` to write `.sql` files, then apply \
457 them with your Drizzle migration workflow."
458 ),
459 "kysely" => format!(
460 "The {command} command uses the sqlx adapter in rustauth.toml. \
461 Set `database.adapter = \"sqlx\"` and configure `database.provider`, or run \
462 `rustauth db generate` and apply the SQL with your existing Kysely tooling."
463 ),
464 "memory" => format!(
465 "The {command} command does not apply migrations for the in-memory adapter. \
466 Use `database.adapter = \"sqlx\"` with a real provider for CLI migrations, or \
467 `rustauth schema print` to inspect the target schema."
468 ),
469 "mongodb" => format!(
470 "The {command} command does not support MongoDB. \
471 Use a SQL provider with {}",
472 enabled_adapter_guidance()
473 ),
474 other => format!(
475 "Unsupported database adapter `{other}` for {command}. \
476 RustAuth CLI migrations require {}",
477 enabled_adapter_guidance()
478 ),
479 }
480}
481
482fn validate_cli_migration_adapter(config: &CliConfig) -> Result<(), DbCliError> {
483 let adapter = config.database.adapter.as_str();
484 if is_adapter_feature_disabled(adapter) {
485 return Err(DbCliError::AdapterFeatureDisabled(
486 adapter.to_owned(),
487 adapter_cargo_feature(adapter).to_owned(),
488 ));
489 }
490 if !is_cli_migration_adapter(adapter) {
491 return Err(DbCliError::UnsupportedAdapter(
492 config.database.adapter.clone(),
493 ));
494 }
495 Ok(())
496}
497
498fn adapter_dispatch_error(adapter: &str) -> DbCliError {
499 if is_adapter_feature_disabled(adapter) {
500 DbCliError::AdapterFeatureDisabled(
501 adapter.to_owned(),
502 adapter_cargo_feature(adapter).to_owned(),
503 )
504 } else {
505 DbCliError::UnsupportedAdapter(adapter.to_owned())
506 }
507}
508
/// Maps an adapter name to the Cargo feature that enables it.
///
/// The feature names equal the adapter names; unrecognised adapters map to
/// `"unknown"`.
fn adapter_cargo_feature(adapter: &str) -> &'static str {
    const FEATURES: [&str; 3] = ["sqlx", "tokio-postgres", "deadpool-postgres"];
    FEATURES
        .into_iter()
        .find(|feature| *feature == adapter)
        .unwrap_or("unknown")
}
517
518fn unsupported_adapter_support_suffix() -> String {
519 format!("CLI migrations support {}", enabled_adapter_guidance())
520}
521
/// Describes the `database.adapter` settings available in this build as a
/// comma-separated list, or states that no migration adapter is compiled in.
fn enabled_adapter_guidance() -> String {
    let candidates = [
        (
            cfg!(feature = "sqlx"),
            "`database.adapter = \"sqlx\"` (sqlite, postgres, mysql)",
        ),
        (
            cfg!(feature = "tokio-postgres"),
            "`database.adapter = \"tokio-postgres\"` (postgres only)",
        ),
        (
            cfg!(feature = "deadpool-postgres"),
            "`database.adapter = \"deadpool-postgres\"` (postgres only)",
        ),
    ];
    let enabled: Vec<&str> = candidates
        .iter()
        .filter(|(is_enabled, _)| *is_enabled)
        .map(|(_, description)| *description)
        .collect();

    if enabled.is_empty() {
        return "no database migration adapters in this CLI build".to_owned();
    }
    enabled.join(", ")
}
539
540fn normalize_database_url(provider: Option<&str>, url: &str, cwd: Option<&Path>) -> String {
541 if !matches!(provider, Some("sqlite" | "sqlite3")) {
542 return url.to_owned();
543 }
544 let Some(cwd) = cwd else {
545 return url.to_owned();
546 };
547 let Some(path) = sqlite_path(url) else {
548 return url.to_owned();
549 };
550 if path.as_os_str().is_empty() || path.is_absolute() {
551 return url.to_owned();
552 }
553 format!("sqlite://{}", cwd.join(path).display())
554}
555
556fn short_hash(input: &[u8]) -> String {
557 let digest = Sha256::digest(input);
558 hex::encode(&digest[..8])
559}
560
561fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
562 if !dir.exists() {
563 return Ok(None);
564 }
565 for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
566 path: dir.to_path_buf(),
567 source,
568 })? {
569 let entry = entry.map_err(|source| DbCliError::Read {
570 path: dir.to_path_buf(),
571 source,
572 })?;
573 let path = entry.path();
574 if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
575 continue;
576 }
577 let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
578 path: path.clone(),
579 source,
580 })?;
581 if content.contains(&format!("plan_hash: {hash}")) {
582 return Ok(Some(path));
583 }
584 }
585 Ok(None)
586}
587
588fn filename_timestamp() -> String {
589 let now = OffsetDateTime::now_utc();
590 format!(
591 "{:04}{:02}{:02}{:02}{:02}{:02}",
592 now.year(),
593 u8::from(now.month()),
594 now.day(),
595 now.hour(),
596 now.minute(),
597 now.second()
598 )
599}
600
/// Collapses provider aliases to a canonical name: `postgresql`/`pg` become
/// `postgres`, `sqlite3` becomes `sqlite`; anything else is returned as-is.
fn normalized_provider(provider: &str) -> &str {
    if matches!(provider, "postgresql" | "pg") {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}
608
609fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
610 let Some(path) = sqlite_path(database_url) else {
611 return Ok(());
612 };
613 if path.as_os_str().is_empty() || path.exists() {
614 return Ok(());
615 }
616 if let Some(parent) = path.parent() {
617 fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
618 path: parent.to_path_buf(),
619 source,
620 })?;
621 }
622 fs::File::create(&path)
623 .map(|_| ())
624 .map_err(|source| DbCliError::Write { path, source })
625}
626
/// Extracts the file-system portion of a SQLite URL.
///
/// Returns `None` for in-memory databases (`sqlite::memory:` and
/// `sqlite://:memory:`, with or without `?query` parameters) and for URLs
/// that do not start with `sqlite:`. Otherwise returns everything after the
/// `sqlite://` or `sqlite:` prefix; any query string is left in place, so
/// callers that rebuild a URL from the result keep their connection
/// parameters.
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    // Decide "in-memory" on the part before the query string, so that
    // `sqlite::memory:?cache=shared` is not mistaken for a file path.
    let location = database_url
        .split_once('?')
        .map_or(database_url, |(location, _query)| location);
    if matches!(location, "sqlite::memory:" | "sqlite://:memory:") {
        return None;
    }
    database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))
        .map(PathBuf::from)
}