1use std::collections::HashMap;
7use std::fmt;
8use std::path::PathBuf;
9
10use serde::Deserialize;
11
12use crate::error::{Result, WaypointError};
13
/// TLS policy requested for the database connection.
///
/// Values are parsed case-insensitively through the `FromStr` impl below.
/// `Prefer` is the default (see `#[default]`).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Parsed from `disable` / `disabled`.
    Disable,
    /// Parsed from `prefer`; used when nothing is configured.
    #[default]
    Prefer,
    /// Parsed from `require` / `required`.
    Require,
}
25
26impl std::str::FromStr for SslMode {
27 type Err = WaypointError;
28
29 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30 match s.to_lowercase().as_str() {
31 "disable" | "disabled" => Ok(SslMode::Disable),
32 "prefer" => Ok(SslMode::Prefer),
33 "require" | "required" => Ok(SslMode::Require),
34 _ => Err(WaypointError::ConfigError(format!(
35 "Invalid SSL mode '{}'. Use 'disable', 'prefer', or 'require'.",
36 s
37 ))),
38 }
39 }
40}
41
42#[derive(Debug, Clone, Default)]
44pub struct WaypointConfig {
45 pub database: DatabaseConfig,
46 pub migrations: MigrationSettings,
47 pub hooks: HooksConfig,
48 pub placeholders: HashMap<String, String>,
49}
50
51#[derive(Clone)]
53pub struct DatabaseConfig {
54 pub url: Option<String>,
55 pub host: Option<String>,
56 pub port: Option<u16>,
57 pub user: Option<String>,
58 pub password: Option<String>,
59 pub database: Option<String>,
60 pub connect_retries: u32,
61 pub ssl_mode: SslMode,
62 pub connect_timeout_secs: u32,
63 pub statement_timeout_secs: u32,
64}
65
66impl Default for DatabaseConfig {
67 fn default() -> Self {
68 Self {
69 url: None,
70 host: None,
71 port: None,
72 user: None,
73 password: None,
74 database: None,
75 connect_retries: 0,
76 ssl_mode: SslMode::Prefer,
77 connect_timeout_secs: 30,
78 statement_timeout_secs: 0,
79 }
80 }
81}
82
83impl fmt::Debug for DatabaseConfig {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("DatabaseConfig")
86 .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
87 .field("host", &self.host)
88 .field("port", &self.port)
89 .field("user", &self.user)
90 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
91 .field("database", &self.database)
92 .field("connect_retries", &self.connect_retries)
93 .field("ssl_mode", &self.ssl_mode)
94 .field("connect_timeout_secs", &self.connect_timeout_secs)
95 .field("statement_timeout_secs", &self.statement_timeout_secs)
96 .finish()
97 }
98}
99
/// Hook file paths, grouped by the point in the migrate run that each field
/// name indicates. The code that executes hooks lives elsewhere.
///
/// All lists are empty by default.
#[derive(Clone, Debug, Default)]
pub struct HooksConfig {
    /// Hooks for the start of a migrate run.
    pub before_migrate: Vec<PathBuf>,
    /// Hooks for the end of a migrate run.
    pub after_migrate: Vec<PathBuf>,
    /// Hooks for before each individual migration.
    pub before_each_migrate: Vec<PathBuf>,
    /// Hooks for after each individual migration.
    pub after_each_migrate: Vec<PathBuf>,
}
108
/// Settings controlling where migrations come from and how their history is
/// recorded.
#[derive(Clone, Debug)]
pub struct MigrationSettings {
    /// Migration source locations (default: `db/migrations`).
    pub locations: Vec<PathBuf>,
    /// Name of the schema history table.
    pub table: String,
    /// Schema that holds the history table.
    pub schema: String,
    /// Whether out-of-order migrations are permitted (off by default).
    pub out_of_order: bool,
    /// Whether validation runs as part of migrate (on by default).
    pub validate_on_migrate: bool,
    /// Whether the clean operation is allowed (off by default).
    pub clean_enabled: bool,
    /// Version string used when baselining (default `"1"`).
    pub baseline_version: String,
    /// Optional override for the recorded "installed by" value.
    pub installed_by: Option<String>,
}

impl Default for MigrationSettings {
    /// Defaults used before any config source is applied.
    fn default() -> Self {
        MigrationSettings {
            locations: vec![PathBuf::from("db/migrations")],
            table: String::from("waypoint_schema_history"),
            schema: String::from("public"),
            out_of_order: false,
            validate_on_migrate: true,
            clean_enabled: false,
            baseline_version: String::from("1"),
            installed_by: None,
        }
    }
}
136
137#[derive(Deserialize, Default)]
140struct TomlConfig {
141 database: Option<TomlDatabaseConfig>,
142 migrations: Option<TomlMigrationSettings>,
143 hooks: Option<TomlHooksConfig>,
144 placeholders: Option<HashMap<String, String>>,
145}
146
147#[derive(Deserialize, Default)]
148struct TomlDatabaseConfig {
149 url: Option<String>,
150 host: Option<String>,
151 port: Option<u16>,
152 user: Option<String>,
153 password: Option<String>,
154 database: Option<String>,
155 connect_retries: Option<u32>,
156 ssl_mode: Option<String>,
157 connect_timeout: Option<u32>,
158 statement_timeout: Option<u32>,
159}
160
161#[derive(Deserialize, Default)]
162struct TomlMigrationSettings {
163 locations: Option<Vec<String>>,
164 table: Option<String>,
165 schema: Option<String>,
166 out_of_order: Option<bool>,
167 validate_on_migrate: Option<bool>,
168 clean_enabled: Option<bool>,
169 baseline_version: Option<String>,
170 installed_by: Option<String>,
171}
172
173#[derive(Deserialize, Default)]
174struct TomlHooksConfig {
175 before_migrate: Option<Vec<String>>,
176 after_migrate: Option<Vec<String>>,
177 before_each_migrate: Option<Vec<String>>,
178 after_each_migrate: Option<Vec<String>>,
179}
180
/// Settings supplied on the command line.
///
/// Every field is optional. `Some` values are applied last by
/// `WaypointConfig::load`, so they override the environment and the config
/// file.
#[derive(Default, Clone)]
pub struct CliOverrides {
    /// Full connection URL; may embed credentials.
    pub url: Option<String>,
    pub schema: Option<String>,
    pub table: Option<String>,
    pub locations: Option<Vec<PathBuf>>,
    pub out_of_order: Option<bool>,
    pub validate_on_migrate: Option<bool>,
    pub baseline_version: Option<String>,
    pub connect_retries: Option<u32>,
    /// SSL mode as typed by the user; parsed when applied.
    pub ssl_mode: Option<String>,
    /// Connect timeout in seconds.
    pub connect_timeout: Option<u32>,
    /// Statement timeout in seconds.
    pub statement_timeout: Option<u32>,
}

// Hand-written instead of derived: `url` may carry a password
// (`postgres://user:pass@host/db`). Redact it the same way `DatabaseConfig`'s
// Debug does, so logging the overrides cannot leak credentials.
impl fmt::Debug for CliOverrides {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CliOverrides")
            .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
            .field("schema", &self.schema)
            .field("table", &self.table)
            .field("locations", &self.locations)
            .field("out_of_order", &self.out_of_order)
            .field("validate_on_migrate", &self.validate_on_migrate)
            .field("baseline_version", &self.baseline_version)
            .field("connect_retries", &self.connect_retries)
            .field("ssl_mode", &self.ssl_mode)
            .field("connect_timeout", &self.connect_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .finish()
    }
}
196
197impl WaypointConfig {
198 pub fn load(config_path: Option<&str>, overrides: &CliOverrides) -> Result<Self> {
204 let mut config = WaypointConfig::default();
205
206 let toml_path = config_path.unwrap_or("waypoint.toml");
208 if let Ok(content) = std::fs::read_to_string(toml_path) {
209 #[cfg(unix)]
211 {
212 use std::os::unix::fs::PermissionsExt;
213 if let Ok(meta) = std::fs::metadata(toml_path) {
214 let mode = meta.permissions().mode();
215 if mode & 0o077 != 0 {
216 tracing::warn!(
217 path = %toml_path,
218 mode = format!("{:o}", mode),
219 "Config file has overly permissive permissions. Consider chmod 600."
220 );
221 }
222 }
223 }
224 let toml_config: TomlConfig = toml::from_str(&content).map_err(|e| {
225 WaypointError::ConfigError(format!(
226 "Failed to parse config file '{}': {}",
227 toml_path, e
228 ))
229 })?;
230 config.apply_toml(toml_config);
231 } else if config_path.is_some() {
232 return Err(WaypointError::ConfigError(format!(
234 "Config file '{}' not found",
235 toml_path
236 )));
237 }
238
239 config.apply_env();
241
242 config.apply_cli(overrides);
244
245 crate::db::validate_identifier(&config.migrations.schema)?;
247 crate::db::validate_identifier(&config.migrations.table)?;
248
249 if config.database.connect_retries > 20 {
251 config.database.connect_retries = 20;
252 tracing::warn!("connect_retries capped at 20");
253 }
254
255 Ok(config)
256 }
257
258 fn apply_toml(&mut self, toml: TomlConfig) {
259 if let Some(db) = toml.database {
260 if let Some(v) = db.url {
261 self.database.url = Some(v);
262 }
263 if let Some(v) = db.host {
264 self.database.host = Some(v);
265 }
266 if let Some(v) = db.port {
267 self.database.port = Some(v);
268 }
269 if let Some(v) = db.user {
270 self.database.user = Some(v);
271 }
272 if let Some(v) = db.password {
273 self.database.password = Some(v);
274 }
275 if let Some(v) = db.database {
276 self.database.database = Some(v);
277 }
278 if let Some(v) = db.connect_retries {
279 self.database.connect_retries = v;
280 }
281 if let Some(v) = db.ssl_mode {
282 if let Ok(mode) = v.parse() {
283 self.database.ssl_mode = mode;
284 }
285 }
286 if let Some(v) = db.connect_timeout {
287 self.database.connect_timeout_secs = v;
288 }
289 if let Some(v) = db.statement_timeout {
290 self.database.statement_timeout_secs = v;
291 }
292 }
293
294 if let Some(m) = toml.migrations {
295 if let Some(v) = m.locations {
296 self.migrations.locations = v.into_iter().map(|s| normalize_location(&s)).collect();
297 }
298 if let Some(v) = m.table {
299 self.migrations.table = v;
300 }
301 if let Some(v) = m.schema {
302 self.migrations.schema = v;
303 }
304 if let Some(v) = m.out_of_order {
305 self.migrations.out_of_order = v;
306 }
307 if let Some(v) = m.validate_on_migrate {
308 self.migrations.validate_on_migrate = v;
309 }
310 if let Some(v) = m.clean_enabled {
311 self.migrations.clean_enabled = v;
312 }
313 if let Some(v) = m.baseline_version {
314 self.migrations.baseline_version = v;
315 }
316 if let Some(v) = m.installed_by {
317 self.migrations.installed_by = Some(v);
318 }
319 }
320
321 if let Some(h) = toml.hooks {
322 if let Some(v) = h.before_migrate {
323 self.hooks.before_migrate = v.into_iter().map(PathBuf::from).collect();
324 }
325 if let Some(v) = h.after_migrate {
326 self.hooks.after_migrate = v.into_iter().map(PathBuf::from).collect();
327 }
328 if let Some(v) = h.before_each_migrate {
329 self.hooks.before_each_migrate = v.into_iter().map(PathBuf::from).collect();
330 }
331 if let Some(v) = h.after_each_migrate {
332 self.hooks.after_each_migrate = v.into_iter().map(PathBuf::from).collect();
333 }
334 }
335
336 if let Some(p) = toml.placeholders {
337 self.placeholders.extend(p);
338 }
339 }
340
341 fn apply_env(&mut self) {
342 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_URL") {
343 self.database.url = Some(v);
344 }
345 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_HOST") {
346 self.database.host = Some(v);
347 }
348 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PORT") {
349 if let Ok(port) = v.parse::<u16>() {
350 self.database.port = Some(port);
351 }
352 }
353 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_USER") {
354 self.database.user = Some(v);
355 }
356 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PASSWORD") {
357 self.database.password = Some(v);
358 }
359 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_NAME") {
360 self.database.database = Some(v);
361 }
362 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_RETRIES") {
363 if let Ok(n) = v.parse::<u32>() {
364 self.database.connect_retries = n;
365 }
366 }
367 if let Ok(v) = std::env::var("WAYPOINT_SSL_MODE") {
368 if let Ok(mode) = v.parse() {
369 self.database.ssl_mode = mode;
370 }
371 }
372 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_TIMEOUT") {
373 if let Ok(n) = v.parse::<u32>() {
374 self.database.connect_timeout_secs = n;
375 }
376 }
377 if let Ok(v) = std::env::var("WAYPOINT_STATEMENT_TIMEOUT") {
378 if let Ok(n) = v.parse::<u32>() {
379 self.database.statement_timeout_secs = n;
380 }
381 }
382 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_LOCATIONS") {
383 self.migrations.locations =
384 v.split(',').map(|s| normalize_location(s.trim())).collect();
385 }
386 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_TABLE") {
387 self.migrations.table = v;
388 }
389 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_SCHEMA") {
390 self.migrations.schema = v;
391 }
392
393 for (key, value) in std::env::vars() {
395 if let Some(placeholder_key) = key.strip_prefix("WAYPOINT_PLACEHOLDER_") {
396 self.placeholders
397 .insert(placeholder_key.to_lowercase(), value);
398 }
399 }
400 }
401
402 fn apply_cli(&mut self, overrides: &CliOverrides) {
403 if let Some(ref v) = overrides.url {
404 self.database.url = Some(v.clone());
405 }
406 if let Some(ref v) = overrides.schema {
407 self.migrations.schema = v.clone();
408 }
409 if let Some(ref v) = overrides.table {
410 self.migrations.table = v.clone();
411 }
412 if let Some(ref v) = overrides.locations {
413 self.migrations.locations = v.clone();
414 }
415 if let Some(v) = overrides.out_of_order {
416 self.migrations.out_of_order = v;
417 }
418 if let Some(v) = overrides.validate_on_migrate {
419 self.migrations.validate_on_migrate = v;
420 }
421 if let Some(ref v) = overrides.baseline_version {
422 self.migrations.baseline_version = v.clone();
423 }
424 if let Some(v) = overrides.connect_retries {
425 self.database.connect_retries = v;
426 }
427 if let Some(ref v) = overrides.ssl_mode {
428 if let Ok(mode) = v.parse() {
430 self.database.ssl_mode = mode;
431 }
432 }
433 if let Some(v) = overrides.connect_timeout {
434 self.database.connect_timeout_secs = v;
435 }
436 if let Some(v) = overrides.statement_timeout {
437 self.database.statement_timeout_secs = v;
438 }
439 }
440
441 pub fn connection_string(&self) -> Result<String> {
446 if let Some(ref url) = self.database.url {
447 return Ok(normalize_jdbc_url(url));
448 }
449
450 let host = self.database.host.as_deref().unwrap_or("localhost");
451 let port = self.database.port.unwrap_or(5432);
452 let user =
453 self.database.user.as_deref().ok_or_else(|| {
454 WaypointError::ConfigError("Database user is required".to_string())
455 })?;
456 let database =
457 self.database.database.as_deref().ok_or_else(|| {
458 WaypointError::ConfigError("Database name is required".to_string())
459 })?;
460
461 let mut url = format!(
462 "host={} port={} user={} dbname={}",
463 host, port, user, database
464 );
465
466 if let Some(ref password) = self.database.password {
467 url.push_str(&format!(" password={}", password));
468 }
469
470 Ok(url)
471 }
472}
473
/// Converts a JDBC-style PostgreSQL URL into a plain connection URL.
///
/// * A leading `jdbc:` prefix is removed.
/// * `user` / `password` query parameters (keys matched case-insensitively)
///   are moved into the userinfo part. For example
///   `jdbc:postgresql://h/db?user=a&password=b` becomes
///   `postgresql://a:b@h/db`.
/// * The move only happens for `postgresql://` / `postgres://` URLs whose
///   authority has no userinfo yet. Otherwise the URL is returned with its
///   query untouched, because rewriting would either drop the credentials or
///   produce a malformed `:pw@user@host`.
/// * In the rewritten case, and when no credentials are present, query
///   segments without `=` are dropped and an empty trailing `?` is removed.
///
/// Values are copied verbatim; no percent-encoding or decoding is done.
fn normalize_jdbc_url(url: &str) -> String {
    let url = url.strip_prefix("jdbc:").unwrap_or(url);

    let (base, query) = match url.split_once('?') {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    let mut user = None;
    let mut password = None;
    let mut other_params = Vec::new();
    for param in query.split('&') {
        if let Some((key, value)) = param.split_once('=') {
            match key.to_lowercase().as_str() {
                "user" => user = Some(value),
                "password" => password = Some(value),
                _ => other_params.push(param),
            }
        }
    }

    let query_suffix = if other_params.is_empty() {
        String::new()
    } else {
        format!("?{}", other_params.join("&"))
    };

    if user.is_none() && password.is_none() {
        return format!("{base}{query_suffix}");
    }

    let (scheme, rest) = if let Some(rest) = base.strip_prefix("postgresql://") {
        ("postgresql", rest)
    } else if let Some(rest) = base.strip_prefix("postgres://") {
        ("postgres", rest)
    } else {
        // No `scheme://authority` to attach credentials to; keep them in the
        // query rather than silently discarding them.
        return url.to_string();
    };

    // The authority runs up to the first '/'. If it already has userinfo,
    // prepending another `user:pw@` would produce an unparseable URL.
    let authority = rest.split('/').next().unwrap_or(rest);
    if authority.contains('@') {
        return url.to_string();
    }

    let auth = match (user, password) {
        (Some(u), Some(p)) => format!("{u}:{p}@"),
        (Some(u), None) => format!("{u}@"),
        (None, Some(p)) => format!(":{p}@"),
        (None, None) => String::new(),
    };

    format!("{scheme}://{auth}{rest}{query_suffix}")
}
537
/// Turns a configured migration location into a path.
///
/// A leading Flyway-style `filesystem:` prefix (lower-case, exact) is removed;
/// anything else is used verbatim.
pub fn normalize_location(location: &str) -> PathBuf {
    match location.strip_prefix("filesystem:") {
        Some(path) => PathBuf::from(path),
        None => PathBuf::from(location),
    }
}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = WaypointConfig::default();
        assert_eq!(config.migrations.table, "waypoint_schema_history");
        assert_eq!(config.migrations.schema, "public");
        assert!(!config.migrations.out_of_order);
        assert!(config.migrations.validate_on_migrate);
        assert!(!config.migrations.clean_enabled);
        assert_eq!(config.migrations.baseline_version, "1");
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("db/migrations")]
        );
    }

    #[test]
    fn test_default_database_config() {
        let db = DatabaseConfig::default();
        assert_eq!(db.ssl_mode, SslMode::Prefer);
        assert_eq!(db.connect_timeout_secs, 30);
        assert_eq!(db.statement_timeout_secs, 0);
        assert_eq!(db.connect_retries, 0);
        assert!(db.url.is_none());
    }

    #[test]
    fn test_ssl_mode_parsing() {
        assert!(matches!("disable".parse::<SslMode>(), Ok(SslMode::Disable)));
        assert!(matches!("Disabled".parse::<SslMode>(), Ok(SslMode::Disable)));
        assert!(matches!("PREFER".parse::<SslMode>(), Ok(SslMode::Prefer)));
        assert!(matches!("require".parse::<SslMode>(), Ok(SslMode::Require)));
        assert!(matches!("Required".parse::<SslMode>(), Ok(SslMode::Require)));
        assert!("verify-full".parse::<SslMode>().is_err());
        assert!("".parse::<SslMode>().is_err());
        assert_eq!(SslMode::default(), SslMode::Prefer);
    }

    #[test]
    fn test_database_config_debug_redacts_secrets() {
        let db = DatabaseConfig {
            url: Some("postgres://u:topsecret@h/db".to_string()),
            password: Some("hunter2".to_string()),
            ..Default::default()
        };
        let rendered = format!("{:?}", db);
        assert!(!rendered.contains("topsecret"));
        assert!(!rendered.contains("hunter2"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[test]
    fn test_connection_string_from_url() {
        let mut config = WaypointConfig::default();
        config.database.url = Some("postgres://user:pass@localhost/db".to_string());
        assert_eq!(
            config.connection_string().unwrap(),
            "postgres://user:pass@localhost/db"
        );
    }

    #[test]
    fn test_connection_string_from_fields() {
        let mut config = WaypointConfig::default();
        config.database.host = Some("myhost".to_string());
        config.database.port = Some(5433);
        config.database.user = Some("myuser".to_string());
        config.database.database = Some("mydb".to_string());
        config.database.password = Some("secret".to_string());

        let conn = config.connection_string().unwrap();
        assert!(conn.contains("host=myhost"));
        assert!(conn.contains("port=5433"));
        assert!(conn.contains("user=myuser"));
        assert!(conn.contains("dbname=mydb"));
        assert!(conn.contains("password=secret"));
    }

    #[test]
    fn test_connection_string_defaults_host_and_port() {
        let mut config = WaypointConfig::default();
        config.database.user = Some("u".to_string());
        config.database.database = Some("d".to_string());
        assert_eq!(
            config.connection_string().unwrap(),
            "host=localhost port=5432 user=u dbname=d"
        );
    }

    #[test]
    fn test_connection_string_missing_user() {
        let mut config = WaypointConfig::default();
        config.database.database = Some("mydb".to_string());
        assert!(config.connection_string().is_err());
    }

    #[test]
    fn test_connection_string_missing_database() {
        let mut config = WaypointConfig::default();
        config.database.user = Some("myuser".to_string());
        assert!(config.connection_string().is_err());
    }

    #[test]
    fn test_cli_overrides() {
        let mut config = WaypointConfig::default();
        let overrides = CliOverrides {
            url: Some("postgres://override@localhost/db".to_string()),
            schema: Some("custom_schema".to_string()),
            table: Some("custom_table".to_string()),
            locations: Some(vec![PathBuf::from("custom/path")]),
            out_of_order: Some(true),
            validate_on_migrate: Some(false),
            baseline_version: Some("5".to_string()),
            ..Default::default()
        };

        config.apply_cli(&overrides);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://override@localhost/db")
        );
        assert_eq!(config.migrations.schema, "custom_schema");
        assert_eq!(config.migrations.table, "custom_table");
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("custom/path")]
        );
        assert!(config.migrations.out_of_order);
        assert!(!config.migrations.validate_on_migrate);
        assert_eq!(config.migrations.baseline_version, "5");
    }

    #[test]
    fn test_cli_overrides_database_settings() {
        let mut config = WaypointConfig::default();
        let overrides = CliOverrides {
            connect_retries: Some(3),
            ssl_mode: Some("require".to_string()),
            connect_timeout: Some(5),
            statement_timeout: Some(60),
            ..Default::default()
        };

        config.apply_cli(&overrides);

        assert_eq!(config.database.connect_retries, 3);
        assert_eq!(config.database.ssl_mode, SslMode::Require);
        assert_eq!(config.database.connect_timeout_secs, 5);
        assert_eq!(config.database.statement_timeout_secs, 60);
    }

    #[test]
    fn test_cli_invalid_ssl_mode_keeps_current() {
        let mut config = WaypointConfig::default();
        let overrides = CliOverrides {
            ssl_mode: Some("bogus".to_string()),
            ..Default::default()
        };
        config.apply_cli(&overrides);
        assert_eq!(config.database.ssl_mode, SslMode::Prefer);
    }

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
[database]
url = "postgres://user:pass@localhost/mydb"

[migrations]
table = "my_history"
schema = "app"
out_of_order = true
locations = ["sql/migrations", "sql/seeds"]

[placeholders]
env = "production"
app_name = "myapp"
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://user:pass@localhost/mydb")
        );
        assert_eq!(config.migrations.table, "my_history");
        assert_eq!(config.migrations.schema, "app");
        assert!(config.migrations.out_of_order);
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("sql/migrations"), PathBuf::from("sql/seeds")]
        );
        assert_eq!(config.placeholders.get("env").unwrap(), "production");
        assert_eq!(config.placeholders.get("app_name").unwrap(), "myapp");
    }

    #[test]
    fn test_toml_database_fields() {
        let toml_str = r#"
[database]
host = "dbhost"
port = 6543
user = "app"
database = "appdb"
ssl_mode = "require"
connect_timeout = 10
statement_timeout = 120
connect_retries = 4
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(config.database.host.as_deref(), Some("dbhost"));
        assert_eq!(config.database.port, Some(6543));
        assert_eq!(config.database.user.as_deref(), Some("app"));
        assert_eq!(config.database.database.as_deref(), Some("appdb"));
        assert_eq!(config.database.ssl_mode, SslMode::Require);
        assert_eq!(config.database.connect_timeout_secs, 10);
        assert_eq!(config.database.statement_timeout_secs, 120);
        assert_eq!(config.database.connect_retries, 4);
    }

    #[test]
    fn test_toml_invalid_ssl_mode_keeps_default() {
        let toml_config: TomlConfig =
            toml::from_str("[database]\nssl_mode = \"sometimes\"\n").unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);
        assert_eq!(config.database.ssl_mode, SslMode::Prefer);
    }

    #[test]
    fn test_toml_hooks_parsing() {
        let toml_str = r#"
[hooks]
before_migrate = ["hooks/pre.sql"]
after_each_migrate = ["hooks/post_each.sql"]
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(
            config.hooks.before_migrate,
            vec![PathBuf::from("hooks/pre.sql")]
        );
        assert_eq!(
            config.hooks.after_each_migrate,
            vec![PathBuf::from("hooks/post_each.sql")]
        );
        assert!(config.hooks.after_migrate.is_empty());
        assert!(config.hooks.before_each_migrate.is_empty());
    }

    #[test]
    fn test_normalize_jdbc_url_with_credentials() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin:secret@myhost:5432/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_user_only() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin@myhost:5432/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_password_only() {
        let url = "jdbc:postgresql://myhost:5432/mydb?password=secret";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://:secret@myhost:5432/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_postgres_scheme() {
        let url = "jdbc:postgres://myhost/mydb?user=admin";
        assert_eq!(normalize_jdbc_url(url), "postgres://admin@myhost/mydb");
    }

    #[test]
    fn test_normalize_jdbc_url_strips_jdbc_prefix() {
        let url = "jdbc:postgresql://myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(url), "postgresql://myhost:5432/mydb");
    }

    #[test]
    fn test_normalize_jdbc_url_passthrough() {
        let url = "postgresql://user:pass@myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(url), url);
    }

    #[test]
    fn test_normalize_jdbc_url_drops_empty_query() {
        assert_eq!(
            normalize_jdbc_url("postgresql://myhost/mydb?"),
            "postgresql://myhost/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_preserves_other_params() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret&sslmode=require";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin:secret@myhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_normalize_location_filesystem_prefix() {
        assert_eq!(
            normalize_location("filesystem:/flyway/sql"),
            PathBuf::from("/flyway/sql")
        );
    }

    #[test]
    fn test_normalize_location_plain_path() {
        assert_eq!(
            normalize_location("/my/migrations"),
            PathBuf::from("/my/migrations")
        );
    }

    #[test]
    fn test_normalize_location_relative() {
        assert_eq!(
            normalize_location("filesystem:db/migrations"),
            PathBuf::from("db/migrations")
        );
    }
}