1use std::collections::HashMap;
7use std::fmt;
8use std::path::PathBuf;
9
10use serde::Deserialize;
11
12use crate::error::{Result, WaypointError};
13
/// Assigns the payload of `$opt` to `$target` when `$opt` is `Some`.
///
/// `$opt` is evaluated once and its value is moved into `$target`; a `None` leaves the
/// target untouched. Used to layer optional config sources over defaults.
macro_rules! apply_option {
    ($opt:expr => $target:expr) => {
        if let Some(inner) = $opt {
            $target = inner;
        }
    };
}
24
/// Like `apply_option!`, but for `Option<T>` targets: a `Some(v)` source becomes
/// `Some(v)` in `$target`, while `None` keeps whatever the target already held.
macro_rules! apply_option_some {
    ($opt:expr => $target:expr) => {
        if let Some(inner) = $opt {
            $target = Some(inner);
        }
    };
}
35
/// Borrowing variant of `apply_option!`: when `$opt` is `Some`, a clone of its payload
/// is written to `$target`. The source option itself is left intact (it is only borrowed).
macro_rules! apply_option_clone {
    ($opt:expr => $target:expr) => {
        if let Some(inner) = &$opt {
            $target = inner.clone();
        }
    };
}
46
/// Borrowing variant of `apply_option_some!`: a `Some` source is cloned into
/// `Some(..)` in the `Option<T>` target; the source is only borrowed.
macro_rules! apply_option_some_clone {
    ($opt:expr => $target:expr) => {
        if let Some(inner) = &$opt {
            $target = Some(inner.clone());
        }
    };
}
57
/// TLS policy for the database connection.
///
/// Parsed case-insensitively from config/env/CLI strings via `FromStr`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SslMode {
    /// TLS disabled (`disable` / `disabled`).
    Disable,
    /// TLS preferred (`prefer`); this is the `Default` variant.
    #[default]
    Prefer,
    /// TLS required (`require` / `required`).
    Require,
}
69
70impl std::str::FromStr for SslMode {
71 type Err = WaypointError;
72
73 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
74 match s.to_lowercase().as_str() {
75 "disable" | "disabled" => Ok(SslMode::Disable),
76 "prefer" => Ok(SslMode::Prefer),
77 "require" | "required" => Ok(SslMode::Require),
78 _ => Err(WaypointError::ConfigError(format!(
79 "Invalid SSL mode '{}'. Use 'disable', 'prefer', or 'require'.",
80 s
81 ))),
82 }
83 }
84}
85
86#[derive(Debug, Clone, Default)]
88pub struct WaypointConfig {
89 pub database: DatabaseConfig,
91 pub migrations: MigrationSettings,
93 pub hooks: HooksConfig,
95 pub placeholders: HashMap<String, String>,
97 pub lint: LintConfig,
99 pub snapshots: crate::commands::snapshot::SnapshotConfig,
101 pub preflight: crate::preflight::PreflightConfig,
103 pub multi_database: Option<Vec<crate::multi::NamedDatabaseConfig>>,
105 pub guards: crate::guard::GuardsConfig,
107 pub reversals: crate::reversal::ReversalConfig,
109 pub safety: crate::safety::SafetyConfig,
111 pub advisor: crate::advisor::AdvisorConfig,
113 pub simulation: SimulationConfig,
115}
116
117#[derive(Clone)]
119pub struct DatabaseConfig {
120 pub url: Option<String>,
122 pub host: Option<String>,
124 pub port: Option<u16>,
126 pub user: Option<String>,
128 pub password: Option<String>,
130 pub database: Option<String>,
132 pub connect_retries: u32,
134 pub ssl_mode: SslMode,
136 pub connect_timeout_secs: u32,
138 pub statement_timeout_secs: u32,
140 pub keepalive_secs: u32,
142}
143
144impl Default for DatabaseConfig {
145 fn default() -> Self {
146 Self {
147 url: None,
148 host: None,
149 port: None,
150 user: None,
151 password: None,
152 database: None,
153 connect_retries: 0,
154 ssl_mode: SslMode::Prefer,
155 connect_timeout_secs: 30,
156 statement_timeout_secs: 0,
157 keepalive_secs: 120,
158 }
159 }
160}
161
162impl fmt::Debug for DatabaseConfig {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 f.debug_struct("DatabaseConfig")
165 .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
166 .field("host", &self.host)
167 .field("port", &self.port)
168 .field("user", &self.user)
169 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
170 .field("database", &self.database)
171 .field("connect_retries", &self.connect_retries)
172 .field("ssl_mode", &self.ssl_mode)
173 .field("connect_timeout_secs", &self.connect_timeout_secs)
174 .field("statement_timeout_secs", &self.statement_timeout_secs)
175 .field("keepalive_secs", &self.keepalive_secs)
176 .finish()
177 }
178}
179
/// Hook script paths, taken verbatim from config (no resolution happens here).
#[derive(Debug, Clone, Default)]
pub struct HooksConfig {
    /// Scripts for the `before_migrate` hook.
    pub before_migrate: Vec<PathBuf>,
    /// Scripts for the `after_migrate` hook.
    pub after_migrate: Vec<PathBuf>,
    /// Scripts for the `before_each_migrate` hook.
    pub before_each_migrate: Vec<PathBuf>,
    /// Scripts for the `after_each_migrate` hook.
    pub after_each_migrate: Vec<PathBuf>,
}
192
/// Lint settings.
#[derive(Debug, Clone, Default)]
pub struct LintConfig {
    /// Identifiers of lint rules to skip.
    pub disabled_rules: Vec<String>,
}
199
/// Where migrations live and how the history table is managed.
#[derive(Debug, Clone)]
pub struct MigrationSettings {
    /// Directories scanned for migration files.
    pub locations: Vec<PathBuf>,
    /// Name of the schema-history table.
    pub table: String,
    /// Schema that holds the history table.
    pub schema: String,
    /// Whether out-of-order migrations are allowed.
    pub out_of_order: bool,
    /// Whether validation runs as part of migrate.
    pub validate_on_migrate: bool,
    /// Whether the destructive `clean` command is permitted.
    pub clean_enabled: bool,
    /// Version recorded by baseline.
    pub baseline_version: String,
    /// Explicit installer name to record, if any.
    pub installed_by: Option<String>,
    /// Environment name, if any.
    pub environment: Option<String>,
    /// Whether dependency-based ordering is enabled.
    pub dependency_ordering: bool,
    /// Whether progress output is shown.
    pub show_progress: bool,
    /// Whether migrations run inside one batch transaction.
    pub batch_transaction: bool,
}
228
229impl Default for MigrationSettings {
230 fn default() -> Self {
231 Self {
232 locations: vec![PathBuf::from("db/migrations")],
233 table: "waypoint_schema_history".to_string(),
234 schema: "public".to_string(),
235 out_of_order: false,
236 validate_on_migrate: true,
237 clean_enabled: false,
238 baseline_version: "1".to_string(),
239 installed_by: None,
240 environment: None,
241 dependency_ordering: false,
242 show_progress: true,
243 batch_transaction: false,
244 }
245 }
246}
247
/// Simulation settings.
#[derive(Debug, Clone, Default)]
pub struct SimulationConfig {
    /// Whether to simulate before migrating (off by default); presumably a dry run of
    /// pending migrations — confirm in the migrate command.
    pub simulate_before_migrate: bool,
}
254
255#[derive(Deserialize, Default)]
258struct TomlConfig {
259 database: Option<TomlDatabaseConfig>,
260 migrations: Option<TomlMigrationSettings>,
261 hooks: Option<TomlHooksConfig>,
262 placeholders: Option<HashMap<String, String>>,
263 lint: Option<TomlLintConfig>,
264 snapshots: Option<TomlSnapshotConfig>,
265 preflight: Option<TomlPreflightConfig>,
266 databases: Option<Vec<TomlNamedDatabaseConfig>>,
267 guards: Option<TomlGuardsConfig>,
268 reversals: Option<TomlReversalConfig>,
269 safety: Option<TomlSafetyConfig>,
270 advisor: Option<TomlAdvisorConfig>,
271 simulation: Option<TomlSimulationConfig>,
272}
273
274#[derive(Deserialize, Default)]
275struct TomlDatabaseConfig {
276 url: Option<String>,
277 host: Option<String>,
278 port: Option<u16>,
279 user: Option<String>,
280 password: Option<String>,
281 database: Option<String>,
282 connect_retries: Option<u32>,
283 ssl_mode: Option<String>,
284 connect_timeout: Option<u32>,
285 statement_timeout: Option<u32>,
286 keepalive: Option<u32>,
287}
288
289#[derive(Deserialize, Default)]
290struct TomlMigrationSettings {
291 locations: Option<Vec<String>>,
292 table: Option<String>,
293 schema: Option<String>,
294 out_of_order: Option<bool>,
295 validate_on_migrate: Option<bool>,
296 clean_enabled: Option<bool>,
297 baseline_version: Option<String>,
298 installed_by: Option<String>,
299 environment: Option<String>,
300 dependency_ordering: Option<bool>,
301 show_progress: Option<bool>,
302 batch_transaction: Option<bool>,
303}
304
305#[derive(Deserialize, Default)]
306struct TomlLintConfig {
307 disabled_rules: Option<Vec<String>>,
308}
309
310#[derive(Deserialize, Default)]
311struct TomlSnapshotConfig {
312 directory: Option<String>,
313 auto_snapshot_on_migrate: Option<bool>,
314 max_snapshots: Option<usize>,
315}
316
317#[derive(Deserialize, Default)]
318struct TomlPreflightConfig {
319 enabled: Option<bool>,
320 max_replication_lag_mb: Option<i64>,
321 long_query_threshold_secs: Option<i64>,
322}
323
324#[derive(Deserialize, Default)]
325struct TomlNamedDatabaseConfig {
326 name: Option<String>,
327 url: Option<String>,
328 depends_on: Option<Vec<String>>,
329 migrations: Option<TomlMigrationSettings>,
330 hooks: Option<TomlHooksConfig>,
331 placeholders: Option<HashMap<String, String>>,
332}
333
334#[derive(Deserialize, Default)]
335struct TomlHooksConfig {
336 before_migrate: Option<Vec<String>>,
337 after_migrate: Option<Vec<String>>,
338 before_each_migrate: Option<Vec<String>>,
339 after_each_migrate: Option<Vec<String>>,
340}
341
342#[derive(Deserialize, Default)]
343struct TomlGuardsConfig {
344 on_require_fail: Option<String>,
345}
346
347#[derive(Deserialize, Default)]
348struct TomlReversalConfig {
349 enabled: Option<bool>,
350 warn_data_loss: Option<bool>,
351}
352
353#[derive(Deserialize, Default)]
354struct TomlSafetyConfig {
355 enabled: Option<bool>,
356 block_on_danger: Option<bool>,
357 large_table_threshold: Option<i64>,
358 huge_table_threshold: Option<i64>,
359}
360
361#[derive(Deserialize, Default)]
362struct TomlAdvisorConfig {
363 run_after_migrate: Option<bool>,
364 disabled_rules: Option<Vec<String>>,
365}
366
367#[derive(Deserialize, Default)]
368struct TomlSimulationConfig {
369 simulate_before_migrate: Option<bool>,
370}
371
/// Values supplied on the command line; the highest-priority configuration layer.
/// `None` means "not given", so the lower layers' value is kept.
#[derive(Debug, Default, Clone)]
pub struct CliOverrides {
    /// Connection URL.
    pub url: Option<String>,
    /// History-table schema.
    pub schema: Option<String>,
    /// History-table name.
    pub table: Option<String>,
    /// Migration directories (used as given).
    pub locations: Option<Vec<PathBuf>>,
    /// Allow out-of-order migrations.
    pub out_of_order: Option<bool>,
    /// Validate as part of migrate.
    pub validate_on_migrate: Option<bool>,
    /// Baseline version.
    pub baseline_version: Option<String>,
    /// Connection retry count.
    pub connect_retries: Option<u32>,
    /// SSL mode name (`disable` / `prefer` / `require`).
    pub ssl_mode: Option<String>,
    /// Connect timeout in seconds.
    pub connect_timeout: Option<u32>,
    /// Statement timeout in seconds.
    pub statement_timeout: Option<u32>,
    /// Environment name.
    pub environment: Option<String>,
    /// Enable dependency-based ordering.
    pub dependency_ordering: Option<bool>,
    /// Keepalive interval in seconds.
    pub keepalive: Option<u32>,
    /// Run migrations in one batch transaction.
    pub batch_transaction: Option<bool>,
}
406
407impl WaypointConfig {
408 pub fn load(config_path: Option<&str>, overrides: &CliOverrides) -> Result<Self> {
414 let mut config = WaypointConfig::default();
415
416 let toml_path = config_path.unwrap_or("waypoint.toml");
418 if let Ok(content) = std::fs::read_to_string(toml_path) {
419 #[cfg(unix)]
421 {
422 use std::os::unix::fs::PermissionsExt;
423 if let Ok(meta) = std::fs::metadata(toml_path) {
424 let mode = meta.permissions().mode();
425 if mode & 0o077 != 0 {
426 log::warn!("Config file has overly permissive permissions. Consider chmod 600.; path={}, mode={:o}", toml_path, mode);
427 }
428 }
429 }
430 let toml_config: TomlConfig = toml::from_str(&content).map_err(|e| {
431 WaypointError::ConfigError(format!(
432 "Failed to parse config file '{}': {}",
433 toml_path, e
434 ))
435 })?;
436 config.apply_toml(toml_config);
437 } else if config_path.is_some() {
438 return Err(WaypointError::ConfigError(format!(
440 "Config file '{}' not found",
441 toml_path
442 )));
443 }
444
445 config.apply_env();
447
448 config.apply_cli(overrides);
450
451 crate::db::validate_identifier(&config.migrations.schema)?;
453 crate::db::validate_identifier(&config.migrations.table)?;
454
455 if config.database.connect_retries > 20 {
457 config.database.connect_retries = 20;
458 log::warn!("connect_retries capped at 20");
459 }
460
461 Ok(config)
462 }
463
464 fn apply_toml(&mut self, toml: TomlConfig) {
465 if let Some(db) = toml.database {
466 apply_option_some!(db.url => self.database.url);
467 apply_option_some!(db.host => self.database.host);
468 apply_option_some!(db.port => self.database.port);
469 apply_option_some!(db.user => self.database.user);
470 apply_option_some!(db.password => self.database.password);
471 apply_option_some!(db.database => self.database.database);
472 apply_option!(db.connect_retries => self.database.connect_retries);
473 if let Some(v) = db.ssl_mode {
474 match v.parse() {
475 Ok(mode) => self.database.ssl_mode = mode,
476 Err(_) => log::warn!(
477 "Invalid ssl_mode '{}' in config, using default 'prefer'. Valid values: disable, prefer, require",
478 v
479 ),
480 }
481 }
482 apply_option!(db.connect_timeout => self.database.connect_timeout_secs);
483 apply_option!(db.statement_timeout => self.database.statement_timeout_secs);
484 apply_option!(db.keepalive => self.database.keepalive_secs);
485 }
486
487 if let Some(m) = toml.migrations {
488 if let Some(v) = m.locations {
489 self.migrations.locations = v.into_iter().map(|s| normalize_location(&s)).collect();
490 }
491 apply_option!(m.table => self.migrations.table);
492 apply_option!(m.schema => self.migrations.schema);
493 apply_option!(m.out_of_order => self.migrations.out_of_order);
494 apply_option!(m.validate_on_migrate => self.migrations.validate_on_migrate);
495 apply_option!(m.clean_enabled => self.migrations.clean_enabled);
496 apply_option!(m.baseline_version => self.migrations.baseline_version);
497 apply_option_some!(m.installed_by => self.migrations.installed_by);
498 apply_option_some!(m.environment => self.migrations.environment);
499 apply_option!(m.dependency_ordering => self.migrations.dependency_ordering);
500 apply_option!(m.show_progress => self.migrations.show_progress);
501 apply_option!(m.batch_transaction => self.migrations.batch_transaction);
502 }
503
504 if let Some(h) = toml.hooks {
505 if let Some(v) = h.before_migrate {
506 self.hooks.before_migrate = v.into_iter().map(PathBuf::from).collect();
507 }
508 if let Some(v) = h.after_migrate {
509 self.hooks.after_migrate = v.into_iter().map(PathBuf::from).collect();
510 }
511 if let Some(v) = h.before_each_migrate {
512 self.hooks.before_each_migrate = v.into_iter().map(PathBuf::from).collect();
513 }
514 if let Some(v) = h.after_each_migrate {
515 self.hooks.after_each_migrate = v.into_iter().map(PathBuf::from).collect();
516 }
517 }
518
519 if let Some(p) = toml.placeholders {
520 self.placeholders.extend(p);
521 }
522
523 if let Some(l) = toml.lint {
524 apply_option!(l.disabled_rules => self.lint.disabled_rules);
525 }
526
527 if let Some(s) = toml.snapshots {
528 if let Some(v) = s.directory {
529 self.snapshots.directory = PathBuf::from(v);
530 }
531 apply_option!(s.auto_snapshot_on_migrate => self.snapshots.auto_snapshot_on_migrate);
532 apply_option!(s.max_snapshots => self.snapshots.max_snapshots);
533 }
534
535 if let Some(p) = toml.preflight {
536 apply_option!(p.enabled => self.preflight.enabled);
537 apply_option!(p.max_replication_lag_mb => self.preflight.max_replication_lag_mb);
538 apply_option!(p.long_query_threshold_secs => self.preflight.long_query_threshold_secs);
539 }
540
541 if let Some(g) = toml.guards {
542 if let Some(v) = g.on_require_fail {
543 match v.parse() {
544 Ok(policy) => self.guards.on_require_fail = policy,
545 Err(_) => log::warn!(
546 "Invalid on_require_fail '{}' in config, using default 'error'. Valid values: error, warn, skip",
547 v
548 ),
549 }
550 }
551 }
552
553 if let Some(r) = toml.reversals {
554 apply_option!(r.enabled => self.reversals.enabled);
555 apply_option!(r.warn_data_loss => self.reversals.warn_data_loss);
556 }
557
558 if let Some(s) = toml.safety {
559 apply_option!(s.enabled => self.safety.enabled);
560 apply_option!(s.block_on_danger => self.safety.block_on_danger);
561 apply_option!(s.large_table_threshold => self.safety.large_table_threshold);
562 apply_option!(s.huge_table_threshold => self.safety.huge_table_threshold);
563 }
564
565 if let Some(a) = toml.advisor {
566 apply_option!(a.run_after_migrate => self.advisor.run_after_migrate);
567 apply_option!(a.disabled_rules => self.advisor.disabled_rules);
568 }
569
570 if let Some(s) = toml.simulation {
571 apply_option!(s.simulate_before_migrate => self.simulation.simulate_before_migrate);
572 }
573
574 if let Some(databases) = toml.databases {
575 let mut named_dbs = Vec::new();
576 for db in databases {
577 let name = db.name.unwrap_or_default();
578 let mut db_config = DatabaseConfig::default();
579 apply_option_some!(db.url => db_config.url);
580 let env_url_key = format!("WAYPOINT_DB_{}_URL", name.to_uppercase());
582 if let Ok(url) = std::env::var(&env_url_key) {
583 db_config.url = Some(url);
584 }
585
586 let mut mig_settings = MigrationSettings::default();
587 if let Some(m) = db.migrations {
588 if let Some(v) = m.locations {
589 mig_settings.locations =
590 v.into_iter().map(|s| normalize_location(&s)).collect();
591 }
592 apply_option!(m.table => mig_settings.table);
593 apply_option!(m.schema => mig_settings.schema);
594 apply_option!(m.out_of_order => mig_settings.out_of_order);
595 apply_option!(m.validate_on_migrate => mig_settings.validate_on_migrate);
596 apply_option!(m.clean_enabled => mig_settings.clean_enabled);
597 apply_option!(m.baseline_version => mig_settings.baseline_version);
598 apply_option_some!(m.installed_by => mig_settings.installed_by);
599 apply_option_some!(m.environment => mig_settings.environment);
600 apply_option!(m.dependency_ordering => mig_settings.dependency_ordering);
601 apply_option!(m.show_progress => mig_settings.show_progress);
602 apply_option!(m.batch_transaction => mig_settings.batch_transaction);
603 }
604
605 let mut hooks_config = HooksConfig::default();
606 if let Some(h) = db.hooks {
607 if let Some(v) = h.before_migrate {
608 hooks_config.before_migrate = v.into_iter().map(PathBuf::from).collect();
609 }
610 if let Some(v) = h.after_migrate {
611 hooks_config.after_migrate = v.into_iter().map(PathBuf::from).collect();
612 }
613 if let Some(v) = h.before_each_migrate {
614 hooks_config.before_each_migrate =
615 v.into_iter().map(PathBuf::from).collect();
616 }
617 if let Some(v) = h.after_each_migrate {
618 hooks_config.after_each_migrate =
619 v.into_iter().map(PathBuf::from).collect();
620 }
621 }
622
623 named_dbs.push(crate::multi::NamedDatabaseConfig {
624 name,
625 database: db_config,
626 migrations: mig_settings,
627 hooks: hooks_config,
628 placeholders: db.placeholders.unwrap_or_default(),
629 depends_on: db.depends_on.unwrap_or_default(),
630 });
631 }
632 self.multi_database = Some(named_dbs);
633 }
634 }
635
636 fn apply_env(&mut self) {
637 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_URL") {
638 self.database.url = Some(v);
639 }
640 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_HOST") {
641 self.database.host = Some(v);
642 }
643 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PORT") {
644 if let Ok(port) = v.parse::<u16>() {
645 self.database.port = Some(port);
646 }
647 }
648 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_USER") {
649 self.database.user = Some(v);
650 }
651 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PASSWORD") {
652 self.database.password = Some(v);
653 }
654 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_NAME") {
655 self.database.database = Some(v);
656 }
657 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_RETRIES") {
658 if let Ok(n) = v.parse::<u32>() {
659 self.database.connect_retries = n;
660 }
661 }
662 if let Ok(v) = std::env::var("WAYPOINT_SSL_MODE") {
663 if let Ok(mode) = v.parse() {
664 self.database.ssl_mode = mode;
665 }
666 }
667 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_TIMEOUT") {
668 if let Ok(n) = v.parse::<u32>() {
669 self.database.connect_timeout_secs = n;
670 }
671 }
672 if let Ok(v) = std::env::var("WAYPOINT_STATEMENT_TIMEOUT") {
673 if let Ok(n) = v.parse::<u32>() {
674 self.database.statement_timeout_secs = n;
675 }
676 }
677 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_LOCATIONS") {
678 self.migrations.locations =
679 v.split(',').map(|s| normalize_location(s.trim())).collect();
680 }
681 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_TABLE") {
682 self.migrations.table = v;
683 }
684 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_SCHEMA") {
685 self.migrations.schema = v;
686 }
687
688 if let Ok(v) = std::env::var("WAYPOINT_KEEPALIVE") {
689 if let Ok(n) = v.parse::<u32>() {
690 self.database.keepalive_secs = n;
691 }
692 }
693 if let Ok(v) = std::env::var("WAYPOINT_BATCH_TRANSACTION") {
694 self.migrations.batch_transaction = v == "1" || v.eq_ignore_ascii_case("true");
695 }
696 if let Ok(v) = std::env::var("WAYPOINT_ENVIRONMENT") {
697 self.migrations.environment = Some(v);
698 }
699
700 for (key, value) in std::env::vars() {
702 if let Some(placeholder_key) = key.strip_prefix("WAYPOINT_PLACEHOLDER_") {
703 self.placeholders
704 .insert(placeholder_key.to_lowercase(), value);
705 }
706 }
707 }
708
709 fn apply_cli(&mut self, overrides: &CliOverrides) {
710 apply_option_some_clone!(overrides.url => self.database.url);
711 apply_option_clone!(overrides.schema => self.migrations.schema);
712 apply_option_clone!(overrides.table => self.migrations.table);
713 apply_option_clone!(overrides.locations => self.migrations.locations);
714 apply_option!(overrides.out_of_order => self.migrations.out_of_order);
715 apply_option!(overrides.validate_on_migrate => self.migrations.validate_on_migrate);
716 apply_option_clone!(overrides.baseline_version => self.migrations.baseline_version);
717 apply_option!(overrides.connect_retries => self.database.connect_retries);
718 if let Some(ref v) = overrides.ssl_mode {
719 if let Ok(mode) = v.parse() {
721 self.database.ssl_mode = mode;
722 }
723 }
724 apply_option!(overrides.connect_timeout => self.database.connect_timeout_secs);
725 apply_option!(overrides.statement_timeout => self.database.statement_timeout_secs);
726 apply_option_some_clone!(overrides.environment => self.migrations.environment);
727 apply_option!(overrides.dependency_ordering => self.migrations.dependency_ordering);
728 apply_option!(overrides.keepalive => self.database.keepalive_secs);
729 apply_option!(overrides.batch_transaction => self.migrations.batch_transaction);
730 }
731
732 pub fn connection_string(&self) -> Result<String> {
737 if let Some(ref url) = self.database.url {
738 return Ok(normalize_jdbc_url(url));
739 }
740
741 let host = self.database.host.as_deref().unwrap_or("localhost");
742 let port = self.database.port.unwrap_or(5432);
743 let user =
744 self.database.user.as_deref().ok_or_else(|| {
745 WaypointError::ConfigError("Database user is required".to_string())
746 })?;
747 let database =
748 self.database.database.as_deref().ok_or_else(|| {
749 WaypointError::ConfigError("Database name is required".to_string())
750 })?;
751
752 let mut url = format!(
753 "host={} port={} user={} dbname={}",
754 host, port, user, database
755 );
756
757 if let Some(ref password) = self.database.password {
758 let escaped = password.replace('\\', "\\\\").replace('\'', "\\'");
760 url.push_str(&format!(" password='{}'", escaped));
761 }
762
763 Ok(url)
764 }
765}
766
/// Converts a JDBC-style PostgreSQL URL into a plain `postgresql://` / `postgres://` URL.
///
/// * A leading `jdbc:` prefix is stripped.
/// * `user=` / `password=` query parameters (key matched case-insensitively) are moved
///   into the URL's userinfo. Characters not allowed there (e.g. `@`, `:`, `/`, spaces,
///   non-ASCII) are percent-encoded; existing `%XX` escapes are kept as-is.
/// * Every other query parameter is kept in order, including flag-style parameters
///   without `=`; empty segments (`&&`, trailing `&`, bare `?`) are dropped.
///
/// If the credentials cannot be moved, the URL is returned with its query untouched
/// rather than silently discarding them. That happens when the scheme is not
/// `postgresql://` / `postgres://`, or when the URL already has a userinfo section.
fn normalize_jdbc_url(url: &str) -> String {
    /// Percent-encodes every byte outside RFC 3986 `unreserved` / `sub-delims`, keeping `%`.
    fn encode_userinfo(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || b"-._~!$&'()*+,;=%".contains(&byte) {
                out.push(byte as char);
            } else {
                out.push_str(&format!("%{:02X}", byte));
            }
        }
        out
    }

    let url = url.strip_prefix("jdbc:").unwrap_or(url);

    let (base, query) = match url.split_once('?') {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    let mut user = None;
    let mut password = None;
    let mut other_params = Vec::new();
    for param in query.split('&').filter(|p| !p.is_empty()) {
        match param.split_once('=') {
            Some((key, value)) if key.eq_ignore_ascii_case("user") => user = Some(value),
            Some((key, value)) if key.eq_ignore_ascii_case("password") => password = Some(value),
            // Includes flag-style params without '=', which used to be dropped.
            _ => other_params.push(param),
        }
    }

    let remaining_query = other_params.join("&");
    let with_query = |head: String| {
        if remaining_query.is_empty() {
            head
        } else {
            format!("{}?{}", head, remaining_query)
        }
    };

    let userinfo = match (user, password) {
        (None, None) => return with_query(base.to_string()),
        (Some(u), None) => format!("{}@", encode_userinfo(u)),
        (None, Some(p)) => format!(":{}@", encode_userinfo(p)),
        (Some(u), Some(p)) => format!("{}:{}@", encode_userinfo(u), encode_userinfo(p)),
    };

    let (scheme, rest) = if let Some(rest) = base.strip_prefix("postgresql://") {
        ("postgresql", rest)
    } else if let Some(rest) = base.strip_prefix("postgres://") {
        ("postgres", rest)
    } else {
        // No authority to attach credentials to; keep them in the query.
        return url.to_string();
    };

    let authority = rest.find('/').map_or(rest, |end| &rest[..end]);
    if authority.contains('@') {
        // Userinfo already present; injecting more would yield "a@b@host".
        return url.to_string();
    }

    with_query(format!("{}://{}{}", scheme, userinfo, rest))
}
830
/// Converts a migration location into a path, accepting Flyway-style
/// `filesystem:<path>` entries by dropping the `filesystem:` prefix.
pub fn normalize_location(location: &str) -> PathBuf {
    let path = match location.strip_prefix("filesystem:") {
        Some(rest) => rest,
        None => location,
    };
    PathBuf::from(path)
}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config that connects through discrete host/port/user/dbname/password
    /// fields rather than a URL.
    fn field_config(
        host: &str,
        port: u16,
        user: &str,
        dbname: &str,
        password: &str,
    ) -> WaypointConfig {
        WaypointConfig {
            database: DatabaseConfig {
                host: Some(host.to_owned()),
                port: Some(port),
                user: Some(user.to_owned()),
                database: Some(dbname.to_owned()),
                password: Some(password.to_owned()),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Asserts that `input` normalizes to `expected`.
    fn assert_normalizes(input: &str, expected: &str) {
        assert_eq!(normalize_jdbc_url(input), expected, "normalizing {}", input);
    }

    #[test]
    fn test_default_config() {
        let defaults = WaypointConfig::default();
        let migrations = &defaults.migrations;
        assert_eq!(migrations.table, "waypoint_schema_history");
        assert_eq!(migrations.schema, "public");
        assert!(!migrations.out_of_order);
        assert!(migrations.validate_on_migrate);
        assert!(!migrations.clean_enabled);
        assert_eq!(migrations.baseline_version, "1");
        assert_eq!(migrations.locations, vec![PathBuf::from("db/migrations")]);
    }

    #[test]
    fn test_connection_string_from_url() {
        let config = WaypointConfig {
            database: DatabaseConfig {
                url: Some(String::from("postgres://user:pass@localhost/db")),
                ..Default::default()
            },
            ..Default::default()
        };
        let conn = config.connection_string().unwrap();
        assert_eq!(conn, "postgres://user:pass@localhost/db");
    }

    #[test]
    fn test_connection_string_from_fields() {
        let conn = field_config("myhost", 5433, "myuser", "mydb", "secret")
            .connection_string()
            .unwrap();
        for expected in [
            "host=myhost",
            "port=5433",
            "user=myuser",
            "dbname=mydb",
            "password='secret'",
        ] {
            assert!(conn.contains(expected), "missing `{}` in `{}`", expected, conn);
        }
    }

    #[test]
    fn test_connection_string_missing_user() {
        let mut config = WaypointConfig::default();
        config.database.database = Some(String::from("mydb"));
        let outcome = config.connection_string();
        assert!(outcome.is_err(), "a missing user must be rejected");
    }

    #[test]
    fn test_cli_overrides() {
        let overrides = CliOverrides {
            url: Some(String::from("postgres://override@localhost/db")),
            schema: Some(String::from("custom_schema")),
            table: Some(String::from("custom_table")),
            locations: Some(vec![PathBuf::from("custom/path")]),
            out_of_order: Some(true),
            validate_on_migrate: Some(false),
            baseline_version: Some(String::from("5")),
            ..CliOverrides::default()
        };

        let mut config = WaypointConfig::default();
        config.apply_cli(&overrides);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://override@localhost/db")
        );
        let migrations = &config.migrations;
        assert_eq!(migrations.schema, "custom_schema");
        assert_eq!(migrations.table, "custom_table");
        assert_eq!(migrations.locations, vec![PathBuf::from("custom/path")]);
        assert!(migrations.out_of_order);
        assert!(!migrations.validate_on_migrate);
        assert_eq!(migrations.baseline_version, "5");
    }

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
[database]
url = "postgres://user:pass@localhost/mydb"

[migrations]
table = "my_history"
schema = "app"
out_of_order = true
locations = ["sql/migrations", "sql/seeds"]

[placeholders]
env = "production"
app_name = "myapp"
"#;

        let parsed: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(parsed);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://user:pass@localhost/mydb")
        );
        let migrations = &config.migrations;
        assert_eq!(migrations.table, "my_history");
        assert_eq!(migrations.schema, "app");
        assert!(migrations.out_of_order);
        assert_eq!(
            migrations.locations,
            vec![PathBuf::from("sql/migrations"), PathBuf::from("sql/seeds")]
        );
        assert_eq!(config.placeholders["env"], "production");
        assert_eq!(config.placeholders["app_name"], "myapp");
    }

    #[test]
    fn test_normalize_jdbc_url_with_credentials() {
        assert_normalizes(
            "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret",
            "postgresql://admin:secret@myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_user_only() {
        assert_normalizes(
            "jdbc:postgresql://myhost:5432/mydb?user=admin",
            "postgresql://admin@myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_strips_jdbc_prefix() {
        assert_normalizes(
            "jdbc:postgresql://myhost:5432/mydb",
            "postgresql://myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_passthrough() {
        let already_plain = "postgresql://user:pass@myhost:5432/mydb";
        assert_normalizes(already_plain, already_plain);
    }

    #[test]
    fn test_normalize_jdbc_url_preserves_other_params() {
        assert_normalizes(
            "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret&sslmode=require",
            "postgresql://admin:secret@myhost:5432/mydb?sslmode=require",
        );
    }

    #[test]
    fn test_normalize_location_filesystem_prefix() {
        let normalized = normalize_location("filesystem:/flyway/sql");
        assert_eq!(normalized, PathBuf::from("/flyway/sql"));
    }

    #[test]
    fn test_normalize_location_plain_path() {
        let normalized = normalize_location("/my/migrations");
        assert_eq!(normalized, PathBuf::from("/my/migrations"));
    }

    #[test]
    fn test_normalize_location_relative() {
        let normalized = normalize_location("filesystem:db/migrations");
        assert_eq!(normalized, PathBuf::from("db/migrations"));
    }

    #[test]
    fn test_connection_string_password_special_chars() {
        let conn = field_config("localhost", 5432, "admin", "mydb", "p@ss'w ord")
            .connection_string()
            .unwrap();
        assert!(conn.contains("password='p@ss\\'w ord'"));
    }
}