1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4
5use serde::Deserialize;
6
7use crate::error::{Result, WaypointError};
8
/// TLS policy requested for the PostgreSQL connection.
///
/// Values are parsed case-insensitively from the config file, environment and
/// CLI through the [`FromStr`](std::str::FromStr) impl below.
/// [`SslMode::Prefer`] is the default.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum SslMode {
    /// Parsed from `disable` or `disabled`.
    Disable,
    /// Parsed from `prefer`; used when nothing else is configured.
    #[default]
    Prefer,
    /// Parsed from `require` or `required`.
    Require,
}
20
21impl std::str::FromStr for SslMode {
22 type Err = WaypointError;
23
24 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
25 match s.to_lowercase().as_str() {
26 "disable" | "disabled" => Ok(SslMode::Disable),
27 "prefer" => Ok(SslMode::Prefer),
28 "require" | "required" => Ok(SslMode::Require),
29 _ => Err(WaypointError::ConfigError(format!(
30 "Invalid SSL mode '{}'. Use 'disable', 'prefer', or 'require'.",
31 s
32 ))),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Default)]
39pub struct WaypointConfig {
40 pub database: DatabaseConfig,
41 pub migrations: MigrationSettings,
42 pub hooks: HooksConfig,
43 pub placeholders: HashMap<String, String>,
44}
45
46#[derive(Clone)]
48pub struct DatabaseConfig {
49 pub url: Option<String>,
50 pub host: Option<String>,
51 pub port: Option<u16>,
52 pub user: Option<String>,
53 pub password: Option<String>,
54 pub database: Option<String>,
55 pub connect_retries: u32,
56 pub ssl_mode: SslMode,
57 pub connect_timeout_secs: u32,
58 pub statement_timeout_secs: u32,
59}
60
61impl Default for DatabaseConfig {
62 fn default() -> Self {
63 Self {
64 url: None,
65 host: None,
66 port: None,
67 user: None,
68 password: None,
69 database: None,
70 connect_retries: 0,
71 ssl_mode: SslMode::Prefer,
72 connect_timeout_secs: 30,
73 statement_timeout_secs: 0,
74 }
75 }
76}
77
78impl fmt::Debug for DatabaseConfig {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 f.debug_struct("DatabaseConfig")
81 .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
82 .field("host", &self.host)
83 .field("port", &self.port)
84 .field("user", &self.user)
85 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
86 .field("database", &self.database)
87 .field("connect_retries", &self.connect_retries)
88 .field("ssl_mode", &self.ssl_mode)
89 .field("connect_timeout_secs", &self.connect_timeout_secs)
90 .field("statement_timeout_secs", &self.statement_timeout_secs)
91 .finish()
92 }
93}
94
/// Hook script paths, taken from the `[hooks]` TOML table.
///
/// Each list is empty unless the matching key is present in the file.
#[derive(Default, Debug, Clone)]
pub struct HooksConfig {
    /// Paths from `hooks.before_migrate`.
    pub before_migrate: Vec<PathBuf>,
    /// Paths from `hooks.after_migrate`.
    pub after_migrate: Vec<PathBuf>,
    /// Paths from `hooks.before_each_migrate`.
    pub before_each_migrate: Vec<PathBuf>,
    /// Paths from `hooks.after_each_migrate`.
    pub after_each_migrate: Vec<PathBuf>,
}
103
/// Settings controlling where migrations are found and how history is tracked.
#[derive(Debug, Clone)]
pub struct MigrationSettings {
    /// Directories to scan (default `db/migrations`); a `filesystem:` prefix
    /// is stripped when read from TOML or the environment.
    pub locations: Vec<PathBuf>,
    /// History table name (default `waypoint_schema_history`); validated as
    /// an identifier by `WaypointConfig::load`.
    pub table: String,
    /// Schema holding the history table (default `public`); validated as an
    /// identifier by `WaypointConfig::load`.
    pub schema: String,
    /// Default `false`.
    pub out_of_order: bool,
    /// Default `true`.
    pub validate_on_migrate: bool,
    /// Default `false`.
    pub clean_enabled: bool,
    /// Default `"1"`.
    pub baseline_version: String,
    /// Optional override for the recorded installer; `None` by default.
    pub installed_by: Option<String>,
}
116
117impl Default for MigrationSettings {
118 fn default() -> Self {
119 Self {
120 locations: vec![PathBuf::from("db/migrations")],
121 table: "waypoint_schema_history".to_string(),
122 schema: "public".to_string(),
123 out_of_order: false,
124 validate_on_migrate: true,
125 clean_enabled: false,
126 baseline_version: "1".to_string(),
127 installed_by: None,
128 }
129 }
130}
131
132#[derive(Deserialize, Default)]
135struct TomlConfig {
136 database: Option<TomlDatabaseConfig>,
137 migrations: Option<TomlMigrationSettings>,
138 hooks: Option<TomlHooksConfig>,
139 placeholders: Option<HashMap<String, String>>,
140}
141
142#[derive(Deserialize, Default)]
143struct TomlDatabaseConfig {
144 url: Option<String>,
145 host: Option<String>,
146 port: Option<u16>,
147 user: Option<String>,
148 password: Option<String>,
149 database: Option<String>,
150 connect_retries: Option<u32>,
151 ssl_mode: Option<String>,
152 connect_timeout: Option<u32>,
153 statement_timeout: Option<u32>,
154}
155
156#[derive(Deserialize, Default)]
157struct TomlMigrationSettings {
158 locations: Option<Vec<String>>,
159 table: Option<String>,
160 schema: Option<String>,
161 out_of_order: Option<bool>,
162 validate_on_migrate: Option<bool>,
163 clean_enabled: Option<bool>,
164 baseline_version: Option<String>,
165 installed_by: Option<String>,
166}
167
168#[derive(Deserialize, Default)]
169struct TomlHooksConfig {
170 before_migrate: Option<Vec<String>>,
171 after_migrate: Option<Vec<String>>,
172 before_each_migrate: Option<Vec<String>>,
173 after_each_migrate: Option<Vec<String>>,
174}
175
/// Command-line values.
///
/// Every `Some` field wins over the config file and the environment, because
/// `WaypointConfig::load` applies these last.
#[derive(Default, Debug, Clone)]
pub struct CliOverrides {
    pub url: Option<String>,
    pub schema: Option<String>,
    pub table: Option<String>,
    /// Used verbatim (no `filesystem:` stripping is applied here).
    pub locations: Option<Vec<PathBuf>>,
    pub out_of_order: Option<bool>,
    pub validate_on_migrate: Option<bool>,
    pub baseline_version: Option<String>,
    pub connect_retries: Option<u32>,
    /// Text form; parsed into `SslMode` when applied.
    pub ssl_mode: Option<String>,
    /// Seconds; stored into `DatabaseConfig::connect_timeout_secs`.
    pub connect_timeout: Option<u32>,
    /// Seconds; stored into `DatabaseConfig::statement_timeout_secs`.
    pub statement_timeout: Option<u32>,
}
191
192impl WaypointConfig {
193 pub fn load(config_path: Option<&str>, overrides: &CliOverrides) -> Result<Self> {
199 let mut config = WaypointConfig::default();
200
201 let toml_path = config_path.unwrap_or("waypoint.toml");
203 if let Ok(content) = std::fs::read_to_string(toml_path) {
204 #[cfg(unix)]
206 {
207 use std::os::unix::fs::PermissionsExt;
208 if let Ok(meta) = std::fs::metadata(toml_path) {
209 let mode = meta.permissions().mode();
210 if mode & 0o077 != 0 {
211 tracing::warn!(
212 path = %toml_path,
213 mode = format!("{:o}", mode),
214 "Config file has overly permissive permissions. Consider chmod 600."
215 );
216 }
217 }
218 }
219 let toml_config: TomlConfig = toml::from_str(&content).map_err(|e| {
220 WaypointError::ConfigError(format!(
221 "Failed to parse config file '{}': {}",
222 toml_path, e
223 ))
224 })?;
225 config.apply_toml(toml_config);
226 } else if config_path.is_some() {
227 return Err(WaypointError::ConfigError(format!(
229 "Config file '{}' not found",
230 toml_path
231 )));
232 }
233
234 config.apply_env();
236
237 config.apply_cli(overrides);
239
240 crate::db::validate_identifier(&config.migrations.schema)?;
242 crate::db::validate_identifier(&config.migrations.table)?;
243
244 if config.database.connect_retries > 20 {
246 config.database.connect_retries = 20;
247 tracing::warn!("connect_retries capped at 20");
248 }
249
250 Ok(config)
251 }
252
253 fn apply_toml(&mut self, toml: TomlConfig) {
254 if let Some(db) = toml.database {
255 if let Some(v) = db.url {
256 self.database.url = Some(v);
257 }
258 if let Some(v) = db.host {
259 self.database.host = Some(v);
260 }
261 if let Some(v) = db.port {
262 self.database.port = Some(v);
263 }
264 if let Some(v) = db.user {
265 self.database.user = Some(v);
266 }
267 if let Some(v) = db.password {
268 self.database.password = Some(v);
269 }
270 if let Some(v) = db.database {
271 self.database.database = Some(v);
272 }
273 if let Some(v) = db.connect_retries {
274 self.database.connect_retries = v;
275 }
276 if let Some(v) = db.ssl_mode {
277 if let Ok(mode) = v.parse() {
278 self.database.ssl_mode = mode;
279 }
280 }
281 if let Some(v) = db.connect_timeout {
282 self.database.connect_timeout_secs = v;
283 }
284 if let Some(v) = db.statement_timeout {
285 self.database.statement_timeout_secs = v;
286 }
287 }
288
289 if let Some(m) = toml.migrations {
290 if let Some(v) = m.locations {
291 self.migrations.locations = v.into_iter().map(|s| normalize_location(&s)).collect();
292 }
293 if let Some(v) = m.table {
294 self.migrations.table = v;
295 }
296 if let Some(v) = m.schema {
297 self.migrations.schema = v;
298 }
299 if let Some(v) = m.out_of_order {
300 self.migrations.out_of_order = v;
301 }
302 if let Some(v) = m.validate_on_migrate {
303 self.migrations.validate_on_migrate = v;
304 }
305 if let Some(v) = m.clean_enabled {
306 self.migrations.clean_enabled = v;
307 }
308 if let Some(v) = m.baseline_version {
309 self.migrations.baseline_version = v;
310 }
311 if let Some(v) = m.installed_by {
312 self.migrations.installed_by = Some(v);
313 }
314 }
315
316 if let Some(h) = toml.hooks {
317 if let Some(v) = h.before_migrate {
318 self.hooks.before_migrate = v.into_iter().map(PathBuf::from).collect();
319 }
320 if let Some(v) = h.after_migrate {
321 self.hooks.after_migrate = v.into_iter().map(PathBuf::from).collect();
322 }
323 if let Some(v) = h.before_each_migrate {
324 self.hooks.before_each_migrate = v.into_iter().map(PathBuf::from).collect();
325 }
326 if let Some(v) = h.after_each_migrate {
327 self.hooks.after_each_migrate = v.into_iter().map(PathBuf::from).collect();
328 }
329 }
330
331 if let Some(p) = toml.placeholders {
332 self.placeholders.extend(p);
333 }
334 }
335
336 fn apply_env(&mut self) {
337 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_URL") {
338 self.database.url = Some(v);
339 }
340 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_HOST") {
341 self.database.host = Some(v);
342 }
343 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PORT") {
344 if let Ok(port) = v.parse::<u16>() {
345 self.database.port = Some(port);
346 }
347 }
348 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_USER") {
349 self.database.user = Some(v);
350 }
351 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PASSWORD") {
352 self.database.password = Some(v);
353 }
354 if let Ok(v) = std::env::var("WAYPOINT_DATABASE_NAME") {
355 self.database.database = Some(v);
356 }
357 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_RETRIES") {
358 if let Ok(n) = v.parse::<u32>() {
359 self.database.connect_retries = n;
360 }
361 }
362 if let Ok(v) = std::env::var("WAYPOINT_SSL_MODE") {
363 if let Ok(mode) = v.parse() {
364 self.database.ssl_mode = mode;
365 }
366 }
367 if let Ok(v) = std::env::var("WAYPOINT_CONNECT_TIMEOUT") {
368 if let Ok(n) = v.parse::<u32>() {
369 self.database.connect_timeout_secs = n;
370 }
371 }
372 if let Ok(v) = std::env::var("WAYPOINT_STATEMENT_TIMEOUT") {
373 if let Ok(n) = v.parse::<u32>() {
374 self.database.statement_timeout_secs = n;
375 }
376 }
377 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_LOCATIONS") {
378 self.migrations.locations =
379 v.split(',').map(|s| normalize_location(s.trim())).collect();
380 }
381 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_TABLE") {
382 self.migrations.table = v;
383 }
384 if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_SCHEMA") {
385 self.migrations.schema = v;
386 }
387
388 for (key, value) in std::env::vars() {
390 if let Some(placeholder_key) = key.strip_prefix("WAYPOINT_PLACEHOLDER_") {
391 self.placeholders
392 .insert(placeholder_key.to_lowercase(), value);
393 }
394 }
395 }
396
397 fn apply_cli(&mut self, overrides: &CliOverrides) {
398 if let Some(ref v) = overrides.url {
399 self.database.url = Some(v.clone());
400 }
401 if let Some(ref v) = overrides.schema {
402 self.migrations.schema = v.clone();
403 }
404 if let Some(ref v) = overrides.table {
405 self.migrations.table = v.clone();
406 }
407 if let Some(ref v) = overrides.locations {
408 self.migrations.locations = v.clone();
409 }
410 if let Some(v) = overrides.out_of_order {
411 self.migrations.out_of_order = v;
412 }
413 if let Some(v) = overrides.validate_on_migrate {
414 self.migrations.validate_on_migrate = v;
415 }
416 if let Some(ref v) = overrides.baseline_version {
417 self.migrations.baseline_version = v.clone();
418 }
419 if let Some(v) = overrides.connect_retries {
420 self.database.connect_retries = v;
421 }
422 if let Some(ref v) = overrides.ssl_mode {
423 if let Ok(mode) = v.parse() {
425 self.database.ssl_mode = mode;
426 }
427 }
428 if let Some(v) = overrides.connect_timeout {
429 self.database.connect_timeout_secs = v;
430 }
431 if let Some(v) = overrides.statement_timeout {
432 self.database.statement_timeout_secs = v;
433 }
434 }
435
436 pub fn connection_string(&self) -> Result<String> {
441 if let Some(ref url) = self.database.url {
442 return Ok(normalize_jdbc_url(url));
443 }
444
445 let host = self.database.host.as_deref().unwrap_or("localhost");
446 let port = self.database.port.unwrap_or(5432);
447 let user =
448 self.database.user.as_deref().ok_or_else(|| {
449 WaypointError::ConfigError("Database user is required".to_string())
450 })?;
451 let database =
452 self.database.database.as_deref().ok_or_else(|| {
453 WaypointError::ConfigError("Database name is required".to_string())
454 })?;
455
456 let mut url = format!(
457 "host={} port={} user={} dbname={}",
458 host, port, user, database
459 );
460
461 if let Some(ref password) = self.database.password {
462 url.push_str(&format!(" password={}", password));
463 }
464
465 Ok(url)
466 }
467}
468
/// Normalizes a (possibly JDBC-style) connection URL.
///
/// - A leading `jdbc:` prefix is removed.
/// - URLs without a query string are otherwise returned unchanged.
/// - For `postgresql://` / `postgres://` URLs whose authority has no userinfo
///   yet, `user` and `password` query parameters (keys matched
///   ASCII-case-insensitively; the last occurrence wins) are moved into the
///   authority as `user:password@`. Characters that would end or mis-split
///   the authority are percent-encoded.
/// - Every other query parameter, including bare flags without `=`, is kept
///   in its original order. Empty segments (`&&`, a trailing `&` or `?`) are
///   dropped.
/// - When credentials cannot be merged (unrecognised scheme, or the authority
///   already has userinfo), they stay in the query string instead of being
///   discarded.
fn normalize_jdbc_url(url: &str) -> String {
    let url = url.strip_prefix("jdbc:").unwrap_or(url);

    let (base, query) = match url.split_once('?') {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    // Empty segments carry no information.
    let params: Vec<&str> = query.split('&').filter(|p| !p.is_empty()).collect();

    let scheme_split = ["postgresql://", "postgres://"]
        .iter()
        .find_map(|prefix| base.strip_prefix(*prefix).map(|rest| (*prefix, rest)));

    if let Some((prefix, rest)) = scheme_split {
        let authority = rest.find('/').map_or(rest, |end| &rest[..end]);
        // Merging into an authority that already contains `user@` would add a
        // second `@`, which drivers split at the wrong place.
        if !authority.contains('@') {
            let mut user = None;
            let mut password = None;
            let mut other_params = Vec::with_capacity(params.len());
            for &param in &params {
                match param.split_once('=') {
                    Some((key, value)) if key.eq_ignore_ascii_case("user") => user = Some(value),
                    Some((key, value)) if key.eq_ignore_ascii_case("password") => {
                        password = Some(value)
                    }
                    _ => other_params.push(param),
                }
            }

            let auth = match (user, password) {
                (Some(u), Some(p)) => {
                    format!("{}:{}@", encode_userinfo(u, true), encode_userinfo(p, false))
                }
                (Some(u), None) => format!("{}@", encode_userinfo(u, true)),
                (None, Some(p)) => format!(":{}@", encode_userinfo(p, false)),
                (None, None) => String::new(),
            };
            return join_query(&format!("{prefix}{auth}{rest}"), &other_params);
        }
    }

    join_query(base, &params)
}

/// Appends `params` to `base` as a `&`-joined query string, or returns `base`
/// alone when `params` is empty.
fn join_query(base: &str, params: &[&str]) -> String {
    if params.is_empty() {
        base.to_string()
    } else {
        format!("{}?{}", base, params.join("&"))
    }
}

/// Percent-encodes the characters that would end or mis-split a URL authority
/// when `value` is placed in its userinfo: `@`, `/`, `?`, `#`, and `:` when
/// `encode_colon` is set (needed for the user part only, since the first `:`
/// separates user from password). Existing `%XX` escapes pass through
/// untouched, because query-string values are usually already percent-encoded.
fn encode_userinfo(value: &str, encode_colon: bool) -> String {
    let mut encoded = String::with_capacity(value.len());
    for c in value.chars() {
        match c {
            '@' => encoded.push_str("%40"),
            '/' => encoded.push_str("%2F"),
            '?' => encoded.push_str("%3F"),
            '#' => encoded.push_str("%23"),
            ':' if encode_colon => encoded.push_str("%3A"),
            other => encoded.push(other),
        }
    }
    encoded
}
532
/// Converts a migration location string into a path.
///
/// A Flyway-style `filesystem:` prefix (case-sensitive) is removed; any other
/// input is used as-is.
pub fn normalize_location(location: &str) -> PathBuf {
    match location.strip_prefix("filesystem:") {
        Some(path) => PathBuf::from(path),
        None => PathBuf::from(location),
    }
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Built-in defaults must match the documented out-of-the-box behaviour.
    #[test]
    fn test_default_config() {
        let defaults = WaypointConfig::default();
        let migrations = &defaults.migrations;

        assert_eq!(migrations.table, "waypoint_schema_history");
        assert_eq!(migrations.schema, "public");
        assert!(!migrations.out_of_order);
        assert!(migrations.validate_on_migrate);
        assert!(!migrations.clean_enabled);
        assert_eq!(migrations.baseline_version, "1");
        assert_eq!(migrations.locations, vec![PathBuf::from("db/migrations")]);
    }

    /// A plain URL is returned unchanged.
    #[test]
    fn test_connection_string_from_url() {
        let url = "postgres://user:pass@localhost/db";
        let mut cfg = WaypointConfig::default();
        cfg.database.url = Some(url.to_string());

        assert_eq!(cfg.connection_string().unwrap(), url);
    }

    /// Without a URL, the individual fields are rendered as key/value pairs.
    #[test]
    fn test_connection_string_from_fields() {
        let cfg = WaypointConfig {
            database: DatabaseConfig {
                host: Some("myhost".to_string()),
                port: Some(5433),
                user: Some("myuser".to_string()),
                database: Some("mydb".to_string()),
                password: Some("secret".to_string()),
                ..DatabaseConfig::default()
            },
            ..WaypointConfig::default()
        };

        let conn = cfg.connection_string().unwrap();
        for fragment in [
            "host=myhost",
            "port=5433",
            "user=myuser",
            "dbname=mydb",
            "password=secret",
        ] {
            assert!(conn.contains(fragment), "{fragment:?} missing from {conn:?}");
        }
    }

    /// The user is mandatory when no URL is configured.
    #[test]
    fn test_connection_string_missing_user() {
        let mut cfg = WaypointConfig::default();
        cfg.database.database = Some("mydb".to_string());

        let outcome = cfg.connection_string();
        assert!(outcome.is_err());
    }

    /// Every provided CLI override replaces the corresponding default.
    #[test]
    fn test_cli_overrides() {
        let overrides = CliOverrides {
            url: Some("postgres://override@localhost/db".to_string()),
            schema: Some("custom_schema".to_string()),
            table: Some("custom_table".to_string()),
            locations: Some(vec![PathBuf::from("custom/path")]),
            out_of_order: Some(true),
            validate_on_migrate: Some(false),
            baseline_version: Some("5".to_string()),
            ..CliOverrides::default()
        };
        let mut cfg = WaypointConfig::default();
        cfg.apply_cli(&overrides);

        assert_eq!(
            cfg.database.url.as_deref(),
            Some("postgres://override@localhost/db")
        );
        let migrations = &cfg.migrations;
        assert_eq!(migrations.schema, "custom_schema");
        assert_eq!(migrations.table, "custom_table");
        assert_eq!(migrations.locations, vec![PathBuf::from("custom/path")]);
        assert!(migrations.out_of_order);
        assert!(!migrations.validate_on_migrate);
        assert_eq!(migrations.baseline_version, "5");
    }

    /// A representative TOML file populates database, migrations and placeholders.
    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
[database]
url = "postgres://user:pass@localhost/mydb"

[migrations]
table = "my_history"
schema = "app"
out_of_order = true
locations = ["sql/migrations", "sql/seeds"]

[placeholders]
env = "production"
app_name = "myapp"
"#;

        let parsed: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut cfg = WaypointConfig::default();
        cfg.apply_toml(parsed);

        assert_eq!(
            cfg.database.url.as_deref(),
            Some("postgres://user:pass@localhost/mydb")
        );
        assert_eq!(cfg.migrations.table, "my_history");
        assert_eq!(cfg.migrations.schema, "app");
        assert!(cfg.migrations.out_of_order);
        assert_eq!(
            cfg.migrations.locations,
            vec![PathBuf::from("sql/migrations"), PathBuf::from("sql/seeds")]
        );
        assert_eq!(cfg.placeholders.get("env").unwrap(), "production");
        assert_eq!(cfg.placeholders.get("app_name").unwrap(), "myapp");
    }

    #[test]
    fn test_normalize_jdbc_url_with_credentials() {
        let input = "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret";
        let expected = "postgresql://admin:secret@myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(input), expected);
    }

    #[test]
    fn test_normalize_jdbc_url_user_only() {
        let input = "jdbc:postgresql://myhost:5432/mydb?user=admin";
        let expected = "postgresql://admin@myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(input), expected);
    }

    #[test]
    fn test_normalize_jdbc_url_strips_jdbc_prefix() {
        let input = "jdbc:postgresql://myhost:5432/mydb";
        let expected = "postgresql://myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(input), expected);
    }

    #[test]
    fn test_normalize_jdbc_url_passthrough() {
        let input = "postgresql://user:pass@myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(input), input);
    }

    #[test]
    fn test_normalize_jdbc_url_preserves_other_params() {
        let input =
            "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret&sslmode=require";
        let expected = "postgresql://admin:secret@myhost:5432/mydb?sslmode=require";
        assert_eq!(normalize_jdbc_url(input), expected);
    }

    #[test]
    fn test_normalize_location_filesystem_prefix() {
        let got = normalize_location("filesystem:/flyway/sql");
        assert_eq!(got, PathBuf::from("/flyway/sql"));
    }

    #[test]
    fn test_normalize_location_plain_path() {
        let got = normalize_location("/my/migrations");
        assert_eq!(got, PathBuf::from("/my/migrations"));
    }

    #[test]
    fn test_normalize_location_relative() {
        let got = normalize_location("filesystem:db/migrations");
        assert_eq!(got, PathBuf::from("db/migrations"));
    }
}