Skip to main content

waypoint_core/
config.rs

1//! Configuration loading and resolution.
2//!
3//! Supports TOML config files, environment variables, and CLI overrides
4//! with a defined priority order (CLI > env > TOML > defaults).
5
6use std::collections::HashMap;
7use std::fmt;
8use std::path::PathBuf;
9
10use serde::Deserialize;
11
12use crate::error::{Result, WaypointError};
13
/// SSL/TLS connection mode.
///
/// Parsed case-insensitively from config text via `FromStr`. The default
/// (see `#[default]` below) is [`SslMode::Prefer`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SslMode {
    /// Never use TLS; always connect in plaintext.
    Disable,
    /// Try TLS first, fall back to plaintext. This is the default mode.
    #[default]
    Prefer,
    /// Require TLS — fail if handshake fails.
    Require,
}
25
26impl std::str::FromStr for SslMode {
27    type Err = WaypointError;
28
29    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30        match s.to_lowercase().as_str() {
31            "disable" | "disabled" => Ok(SslMode::Disable),
32            "prefer" => Ok(SslMode::Prefer),
33            "require" | "required" => Ok(SslMode::Require),
34            _ => Err(WaypointError::ConfigError(format!(
35                "Invalid SSL mode '{}'. Use 'disable', 'prefer', or 'require'.",
36                s
37            ))),
38        }
39    }
40}
41
42/// Top-level configuration for Waypoint.
43#[derive(Debug, Clone, Default)]
44pub struct WaypointConfig {
45    pub database: DatabaseConfig,
46    pub migrations: MigrationSettings,
47    pub hooks: HooksConfig,
48    pub placeholders: HashMap<String, String>,
49}
50
51/// Database connection configuration.
52#[derive(Clone)]
53pub struct DatabaseConfig {
54    pub url: Option<String>,
55    pub host: Option<String>,
56    pub port: Option<u16>,
57    pub user: Option<String>,
58    pub password: Option<String>,
59    pub database: Option<String>,
60    pub connect_retries: u32,
61    pub ssl_mode: SslMode,
62    pub connect_timeout_secs: u32,
63    pub statement_timeout_secs: u32,
64}
65
66impl Default for DatabaseConfig {
67    fn default() -> Self {
68        Self {
69            url: None,
70            host: None,
71            port: None,
72            user: None,
73            password: None,
74            database: None,
75            connect_retries: 0,
76            ssl_mode: SslMode::Prefer,
77            connect_timeout_secs: 30,
78            statement_timeout_secs: 0,
79        }
80    }
81}
82
83impl fmt::Debug for DatabaseConfig {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("DatabaseConfig")
86            .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
87            .field("host", &self.host)
88            .field("port", &self.port)
89            .field("user", &self.user)
90            .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
91            .field("database", &self.database)
92            .field("connect_retries", &self.connect_retries)
93            .field("ssl_mode", &self.ssl_mode)
94            .field("connect_timeout_secs", &self.connect_timeout_secs)
95            .field("statement_timeout_secs", &self.statement_timeout_secs)
96            .finish()
97    }
98}
99
/// Hook configuration for running SQL before/after migrations.
///
/// Each list holds SQL file paths for one hook point; all lists are empty by
/// default.
#[derive(Clone, Debug, Default)]
pub struct HooksConfig {
    /// Files for the `before_migrate` hook.
    pub before_migrate: Vec<PathBuf>,
    /// Files for the `after_migrate` hook.
    pub after_migrate: Vec<PathBuf>,
    /// Files for the `before_each_migrate` hook.
    pub before_each_migrate: Vec<PathBuf>,
    /// Files for the `after_each_migrate` hook.
    pub after_each_migrate: Vec<PathBuf>,
}
108
/// Migration behavior settings.
#[derive(Clone, Debug)]
pub struct MigrationSettings {
    /// Migration locations. `filesystem:` prefixes are stripped when these come
    /// from TOML or the environment.
    pub locations: Vec<PathBuf>,
    /// Schema-history table name (validated as an identifier by `load`).
    pub table: String,
    /// Schema name (validated as an identifier by `load`).
    pub schema: String,
    /// Whether out-of-order migrations are allowed.
    pub out_of_order: bool,
    /// Whether to validate before migrating.
    pub validate_on_migrate: bool,
    /// Whether the `clean` operation is enabled.
    pub clean_enabled: bool,
    /// Version string used for baselining.
    pub baseline_version: String,
    /// Optional value for the "installed by" field.
    pub installed_by: Option<String>,
}
121
122impl Default for MigrationSettings {
123    fn default() -> Self {
124        Self {
125            locations: vec![PathBuf::from("db/migrations")],
126            table: "waypoint_schema_history".to_string(),
127            schema: "public".to_string(),
128            out_of_order: false,
129            validate_on_migrate: true,
130            clean_enabled: false,
131            baseline_version: "1".to_string(),
132            installed_by: None,
133        }
134    }
135}
136
137// ── TOML deserialization structs ──
138
139#[derive(Deserialize, Default)]
140struct TomlConfig {
141    database: Option<TomlDatabaseConfig>,
142    migrations: Option<TomlMigrationSettings>,
143    hooks: Option<TomlHooksConfig>,
144    placeholders: Option<HashMap<String, String>>,
145}
146
147#[derive(Deserialize, Default)]
148struct TomlDatabaseConfig {
149    url: Option<String>,
150    host: Option<String>,
151    port: Option<u16>,
152    user: Option<String>,
153    password: Option<String>,
154    database: Option<String>,
155    connect_retries: Option<u32>,
156    ssl_mode: Option<String>,
157    connect_timeout: Option<u32>,
158    statement_timeout: Option<u32>,
159}
160
161#[derive(Deserialize, Default)]
162struct TomlMigrationSettings {
163    locations: Option<Vec<String>>,
164    table: Option<String>,
165    schema: Option<String>,
166    out_of_order: Option<bool>,
167    validate_on_migrate: Option<bool>,
168    clean_enabled: Option<bool>,
169    baseline_version: Option<String>,
170    installed_by: Option<String>,
171}
172
173#[derive(Deserialize, Default)]
174struct TomlHooksConfig {
175    before_migrate: Option<Vec<String>>,
176    after_migrate: Option<Vec<String>>,
177    before_each_migrate: Option<Vec<String>>,
178    after_each_migrate: Option<Vec<String>>,
179}
180
/// CLI overrides that take highest priority.
///
/// Every field is optional. `None` means "not given on the command line" and
/// leaves the value from lower-priority layers untouched.
///
/// `Debug` is implemented by hand so that `url`, which may embed a password,
/// is printed as `Some("[REDACTED]")`. This matches how `DatabaseConfig`
/// formats its own `url`.
#[derive(Default, Clone)]
pub struct CliOverrides {
    /// Connection URL (may contain credentials).
    pub url: Option<String>,
    /// Schema name override.
    pub schema: Option<String>,
    /// History table name override.
    pub table: Option<String>,
    /// Migration locations override.
    pub locations: Option<Vec<PathBuf>>,
    /// Out-of-order flag override.
    pub out_of_order: Option<bool>,
    /// Validate-on-migrate flag override.
    pub validate_on_migrate: Option<bool>,
    /// Baseline version override.
    pub baseline_version: Option<String>,
    /// Connection retry count override.
    pub connect_retries: Option<u32>,
    /// Raw SSL mode text; parsed when applied.
    pub ssl_mode: Option<String>,
    /// Connect timeout override, in seconds.
    pub connect_timeout: Option<u32>,
    /// Statement timeout override, in seconds.
    pub statement_timeout: Option<u32>,
}

impl fmt::Debug for CliOverrides {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring so new (possibly sensitive) fields cannot be
        // added without revisiting this impl.
        let CliOverrides {
            url,
            schema,
            table,
            locations,
            out_of_order,
            validate_on_migrate,
            baseline_version,
            connect_retries,
            ssl_mode,
            connect_timeout,
            statement_timeout,
        } = self;
        f.debug_struct("CliOverrides")
            .field("url", &url.as_ref().map(|_| "[REDACTED]"))
            .field("schema", schema)
            .field("table", table)
            .field("locations", locations)
            .field("out_of_order", out_of_order)
            .field("validate_on_migrate", validate_on_migrate)
            .field("baseline_version", baseline_version)
            .field("connect_retries", connect_retries)
            .field("ssl_mode", ssl_mode)
            .field("connect_timeout", connect_timeout)
            .field("statement_timeout", statement_timeout)
            .finish()
    }
}
196
197impl WaypointConfig {
198    /// Load configuration with the following priority (highest wins):
199    /// 1. CLI arguments
200    /// 2. Environment variables
201    /// 3. TOML config file
202    /// 4. Built-in defaults
203    pub fn load(config_path: Option<&str>, overrides: &CliOverrides) -> Result<Self> {
204        let mut config = WaypointConfig::default();
205
206        // Layer 3: TOML config file
207        let toml_path = config_path.unwrap_or("waypoint.toml");
208        if let Ok(content) = std::fs::read_to_string(toml_path) {
209            // Warn if config file has overly permissive permissions (Unix only)
210            #[cfg(unix)]
211            {
212                use std::os::unix::fs::PermissionsExt;
213                if let Ok(meta) = std::fs::metadata(toml_path) {
214                    let mode = meta.permissions().mode();
215                    if mode & 0o077 != 0 {
216                        tracing::warn!(
217                            path = %toml_path,
218                            mode = format!("{:o}", mode),
219                            "Config file has overly permissive permissions. Consider chmod 600."
220                        );
221                    }
222                }
223            }
224            let toml_config: TomlConfig = toml::from_str(&content).map_err(|e| {
225                WaypointError::ConfigError(format!(
226                    "Failed to parse config file '{}': {}",
227                    toml_path, e
228                ))
229            })?;
230            config.apply_toml(toml_config);
231        } else if config_path.is_some() {
232            // If explicitly specified, error if not found
233            return Err(WaypointError::ConfigError(format!(
234                "Config file '{}' not found",
235                toml_path
236            )));
237        }
238
239        // Layer 2: Environment variables
240        config.apply_env();
241
242        // Layer 1: CLI overrides
243        config.apply_cli(overrides);
244
245        // Validate identifiers
246        crate::db::validate_identifier(&config.migrations.schema)?;
247        crate::db::validate_identifier(&config.migrations.table)?;
248
249        // Cap connect_retries at 20
250        if config.database.connect_retries > 20 {
251            config.database.connect_retries = 20;
252            tracing::warn!("connect_retries capped at 20");
253        }
254
255        Ok(config)
256    }
257
258    fn apply_toml(&mut self, toml: TomlConfig) {
259        if let Some(db) = toml.database {
260            if let Some(v) = db.url {
261                self.database.url = Some(v);
262            }
263            if let Some(v) = db.host {
264                self.database.host = Some(v);
265            }
266            if let Some(v) = db.port {
267                self.database.port = Some(v);
268            }
269            if let Some(v) = db.user {
270                self.database.user = Some(v);
271            }
272            if let Some(v) = db.password {
273                self.database.password = Some(v);
274            }
275            if let Some(v) = db.database {
276                self.database.database = Some(v);
277            }
278            if let Some(v) = db.connect_retries {
279                self.database.connect_retries = v;
280            }
281            if let Some(v) = db.ssl_mode {
282                if let Ok(mode) = v.parse() {
283                    self.database.ssl_mode = mode;
284                }
285            }
286            if let Some(v) = db.connect_timeout {
287                self.database.connect_timeout_secs = v;
288            }
289            if let Some(v) = db.statement_timeout {
290                self.database.statement_timeout_secs = v;
291            }
292        }
293
294        if let Some(m) = toml.migrations {
295            if let Some(v) = m.locations {
296                self.migrations.locations = v.into_iter().map(|s| normalize_location(&s)).collect();
297            }
298            if let Some(v) = m.table {
299                self.migrations.table = v;
300            }
301            if let Some(v) = m.schema {
302                self.migrations.schema = v;
303            }
304            if let Some(v) = m.out_of_order {
305                self.migrations.out_of_order = v;
306            }
307            if let Some(v) = m.validate_on_migrate {
308                self.migrations.validate_on_migrate = v;
309            }
310            if let Some(v) = m.clean_enabled {
311                self.migrations.clean_enabled = v;
312            }
313            if let Some(v) = m.baseline_version {
314                self.migrations.baseline_version = v;
315            }
316            if let Some(v) = m.installed_by {
317                self.migrations.installed_by = Some(v);
318            }
319        }
320
321        if let Some(h) = toml.hooks {
322            if let Some(v) = h.before_migrate {
323                self.hooks.before_migrate = v.into_iter().map(PathBuf::from).collect();
324            }
325            if let Some(v) = h.after_migrate {
326                self.hooks.after_migrate = v.into_iter().map(PathBuf::from).collect();
327            }
328            if let Some(v) = h.before_each_migrate {
329                self.hooks.before_each_migrate = v.into_iter().map(PathBuf::from).collect();
330            }
331            if let Some(v) = h.after_each_migrate {
332                self.hooks.after_each_migrate = v.into_iter().map(PathBuf::from).collect();
333            }
334        }
335
336        if let Some(p) = toml.placeholders {
337            self.placeholders.extend(p);
338        }
339    }
340
341    fn apply_env(&mut self) {
342        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_URL") {
343            self.database.url = Some(v);
344        }
345        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_HOST") {
346            self.database.host = Some(v);
347        }
348        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PORT") {
349            if let Ok(port) = v.parse::<u16>() {
350                self.database.port = Some(port);
351            }
352        }
353        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_USER") {
354            self.database.user = Some(v);
355        }
356        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PASSWORD") {
357            self.database.password = Some(v);
358        }
359        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_NAME") {
360            self.database.database = Some(v);
361        }
362        if let Ok(v) = std::env::var("WAYPOINT_CONNECT_RETRIES") {
363            if let Ok(n) = v.parse::<u32>() {
364                self.database.connect_retries = n;
365            }
366        }
367        if let Ok(v) = std::env::var("WAYPOINT_SSL_MODE") {
368            if let Ok(mode) = v.parse() {
369                self.database.ssl_mode = mode;
370            }
371        }
372        if let Ok(v) = std::env::var("WAYPOINT_CONNECT_TIMEOUT") {
373            if let Ok(n) = v.parse::<u32>() {
374                self.database.connect_timeout_secs = n;
375            }
376        }
377        if let Ok(v) = std::env::var("WAYPOINT_STATEMENT_TIMEOUT") {
378            if let Ok(n) = v.parse::<u32>() {
379                self.database.statement_timeout_secs = n;
380            }
381        }
382        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_LOCATIONS") {
383            self.migrations.locations =
384                v.split(',').map(|s| normalize_location(s.trim())).collect();
385        }
386        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_TABLE") {
387            self.migrations.table = v;
388        }
389        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_SCHEMA") {
390            self.migrations.schema = v;
391        }
392
393        // Scan for placeholder env vars: WAYPOINT_PLACEHOLDER_{KEY}
394        for (key, value) in std::env::vars() {
395            if let Some(placeholder_key) = key.strip_prefix("WAYPOINT_PLACEHOLDER_") {
396                self.placeholders
397                    .insert(placeholder_key.to_lowercase(), value);
398            }
399        }
400    }
401
402    fn apply_cli(&mut self, overrides: &CliOverrides) {
403        if let Some(ref v) = overrides.url {
404            self.database.url = Some(v.clone());
405        }
406        if let Some(ref v) = overrides.schema {
407            self.migrations.schema = v.clone();
408        }
409        if let Some(ref v) = overrides.table {
410            self.migrations.table = v.clone();
411        }
412        if let Some(ref v) = overrides.locations {
413            self.migrations.locations = v.clone();
414        }
415        if let Some(v) = overrides.out_of_order {
416            self.migrations.out_of_order = v;
417        }
418        if let Some(v) = overrides.validate_on_migrate {
419            self.migrations.validate_on_migrate = v;
420        }
421        if let Some(ref v) = overrides.baseline_version {
422            self.migrations.baseline_version = v.clone();
423        }
424        if let Some(v) = overrides.connect_retries {
425            self.database.connect_retries = v;
426        }
427        if let Some(ref v) = overrides.ssl_mode {
428            // Ignore parse errors here — they'll be caught in validation
429            if let Ok(mode) = v.parse() {
430                self.database.ssl_mode = mode;
431            }
432        }
433        if let Some(v) = overrides.connect_timeout {
434            self.database.connect_timeout_secs = v;
435        }
436        if let Some(v) = overrides.statement_timeout {
437            self.database.statement_timeout_secs = v;
438        }
439    }
440
441    /// Build a connection string from the config.
442    /// Prefers `url` if set; otherwise builds from individual fields.
443    /// Handles JDBC-style URLs by stripping the `jdbc:` prefix and
444    /// extracting `user` and `password` query parameters.
445    pub fn connection_string(&self) -> Result<String> {
446        if let Some(ref url) = self.database.url {
447            return Ok(normalize_jdbc_url(url));
448        }
449
450        let host = self.database.host.as_deref().unwrap_or("localhost");
451        let port = self.database.port.unwrap_or(5432);
452        let user =
453            self.database.user.as_deref().ok_or_else(|| {
454                WaypointError::ConfigError("Database user is required".to_string())
455            })?;
456        let database =
457            self.database.database.as_deref().ok_or_else(|| {
458                WaypointError::ConfigError("Database name is required".to_string())
459            })?;
460
461        let mut url = format!(
462            "host={} port={} user={} dbname={}",
463            host, port, user, database
464        );
465
466        if let Some(ref password) = self.database.password {
467            url.push_str(&format!(" password={}", password));
468        }
469
470        Ok(url)
471    }
472}
473
/// Normalize a JDBC-style URL to a standard PostgreSQL connection string.
///
/// Handles:
///   - `jdbc:postgresql://host:port/db?user=x&password=y`  →  `postgresql://x:y@host:port/db`
///   - `postgresql://...` passed through as-is
///   - `postgres://...` passed through as-is
///
/// Details:
///   - The `user` and `password` query parameters (key compared
///     case-insensitively; the last occurrence wins) are moved into the URL's
///     userinfo and percent-encoded there. Without encoding, a password
///     containing `@`, `:` or `/` would corrupt the authority. `%` is left
///     as-is, so values already percent-encoded in the JDBC URL are not
///     double-encoded.
///   - Other query parameters keep their original order. Parameters without an
///     `=` are dropped.
///   - If credentials are present but the scheme is not `postgresql://` or
///     `postgres://`, the URL is returned with only `jdbc:` removed, so the
///     credentials are not silently discarded.
fn normalize_jdbc_url(url: &str) -> String {
    // Strip jdbc: prefix
    let url = url.strip_prefix("jdbc:").unwrap_or(url);

    let (base, query) = match url.split_once('?') {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    let mut user = None;
    let mut password = None;
    let mut other_params = Vec::new();
    for param in query.split('&') {
        if let Some((key, value)) = param.split_once('=') {
            match key.to_lowercase().as_str() {
                "user" => user = Some(value),
                "password" => password = Some(value),
                _ => other_params.push(param),
            }
        }
    }

    if user.is_none() && password.is_none() {
        // Nothing to move: return with jdbc: stripped and malformed params dropped.
        return if other_params.is_empty() {
            base.to_string()
        } else {
            format!("{}?{}", base, other_params.join("&"))
        };
    }

    let (scheme, rest) = if let Some(rest) = base.strip_prefix("postgresql://") {
        ("postgresql", rest)
    } else if let Some(rest) = base.strip_prefix("postgres://") {
        ("postgres", rest)
    } else {
        // Unknown scheme: keep the credentials in the query rather than lose them.
        return url.to_string();
    };

    let auth = match (user, password) {
        (Some(u), Some(p)) => format!("{}:{}@", encode_userinfo(u), encode_userinfo(p)),
        (Some(u), None) => format!("{}@", encode_userinfo(u)),
        (None, Some(p)) => format!(":{}@", encode_userinfo(p)),
        (None, None) => String::new(),
    };

    let mut result = format!("{}://{}{}", scheme, auth, rest);
    if !other_params.is_empty() {
        result.push('?');
        result.push_str(&other_params.join("&"));
    }
    result
}

/// Percent-encode `value` for the userinfo component of a URI.
///
/// ASCII alphanumerics, `-`, `.`, `_`, `~` and `%` are copied verbatim. Every
/// other byte, including each byte of a multi-byte UTF-8 character, becomes
/// `%XX` with uppercase hex digits.
fn encode_userinfo(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'%') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
537
/// Convert a configured migration location into a path.
///
/// Flyway-style locations may carry a `filesystem:` scheme prefix (Flyway
/// compatibility). A single leading `filesystem:` is removed and the rest is
/// used verbatim; any other text is used as-is.
pub fn normalize_location(location: &str) -> PathBuf {
    let path = match location.strip_prefix("filesystem:") {
        Some(without_scheme) => without_scheme,
        None => location,
    };
    PathBuf::from(path)
}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = WaypointConfig::default();
        assert_eq!(config.migrations.table, "waypoint_schema_history");
        assert_eq!(config.migrations.schema, "public");
        assert!(!config.migrations.out_of_order);
        assert!(config.migrations.validate_on_migrate);
        assert!(!config.migrations.clean_enabled);
        assert_eq!(config.migrations.baseline_version, "1");
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("db/migrations")]
        );
    }

    #[test]
    fn test_database_config_defaults() {
        let db = DatabaseConfig::default();
        assert!(db.url.is_none());
        assert_eq!(db.connect_retries, 0);
        assert_eq!(db.ssl_mode, SslMode::Prefer);
        assert_eq!(db.connect_timeout_secs, 30);
        assert_eq!(db.statement_timeout_secs, 0);
    }

    #[test]
    fn test_ssl_mode_default_is_prefer() {
        assert_eq!(SslMode::default(), SslMode::Prefer);
    }

    #[test]
    fn test_ssl_mode_parsing() {
        assert!(matches!("disable".parse::<SslMode>(), Ok(SslMode::Disable)));
        assert!(matches!("Disabled".parse::<SslMode>(), Ok(SslMode::Disable)));
        assert!(matches!("PREFER".parse::<SslMode>(), Ok(SslMode::Prefer)));
        assert!(matches!("require".parse::<SslMode>(), Ok(SslMode::Require)));
        assert!(matches!("required".parse::<SslMode>(), Ok(SslMode::Require)));
        assert!("verify-full".parse::<SslMode>().is_err());
        assert!("".parse::<SslMode>().is_err());
    }

    #[test]
    fn test_database_config_debug_redacts_secrets() {
        let db = DatabaseConfig {
            url: Some("postgres://admin:hunter2@db/app".to_string()),
            user: Some("admin".to_string()),
            password: Some("hunter2".to_string()),
            ..Default::default()
        };
        let rendered = format!("{:?}", db);
        assert!(!rendered.contains("hunter2"));
        assert!(rendered.contains("[REDACTED]"));
        assert!(rendered.contains("admin"));
    }

    #[test]
    fn test_connection_string_from_url() {
        let mut config = WaypointConfig::default();
        config.database.url = Some("postgres://user:pass@localhost/db".to_string());
        assert_eq!(
            config.connection_string().unwrap(),
            "postgres://user:pass@localhost/db"
        );
    }

    #[test]
    fn test_connection_string_from_fields() {
        let mut config = WaypointConfig::default();
        config.database.host = Some("myhost".to_string());
        config.database.port = Some(5433);
        config.database.user = Some("myuser".to_string());
        config.database.database = Some("mydb".to_string());
        config.database.password = Some("secret".to_string());

        let conn = config.connection_string().unwrap();
        assert!(conn.contains("host=myhost"));
        assert!(conn.contains("port=5433"));
        assert!(conn.contains("user=myuser"));
        assert!(conn.contains("dbname=mydb"));
        assert!(conn.contains("password=secret"));
    }

    #[test]
    fn test_connection_string_defaults_host_and_port() {
        let mut config = WaypointConfig::default();
        config.database.user = Some("u".to_string());
        config.database.database = Some("d".to_string());
        assert_eq!(
            config.connection_string().unwrap(),
            "host=localhost port=5432 user=u dbname=d"
        );
    }

    #[test]
    fn test_connection_string_missing_user() {
        let mut config = WaypointConfig::default();
        config.database.database = Some("mydb".to_string());
        assert!(config.connection_string().is_err());
    }

    #[test]
    fn test_connection_string_missing_database() {
        let mut config = WaypointConfig::default();
        config.database.user = Some("myuser".to_string());
        assert!(config.connection_string().is_err());
    }

    #[test]
    fn test_cli_overrides() {
        let mut config = WaypointConfig::default();
        let overrides = CliOverrides {
            url: Some("postgres://override@localhost/db".to_string()),
            schema: Some("custom_schema".to_string()),
            table: Some("custom_table".to_string()),
            locations: Some(vec![PathBuf::from("custom/path")]),
            out_of_order: Some(true),
            validate_on_migrate: Some(false),
            baseline_version: Some("5".to_string()),
            connect_retries: None,
            ssl_mode: None,
            connect_timeout: None,
            statement_timeout: None,
        };

        config.apply_cli(&overrides);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://override@localhost/db")
        );
        assert_eq!(config.migrations.schema, "custom_schema");
        assert_eq!(config.migrations.table, "custom_table");
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("custom/path")]
        );
        assert!(config.migrations.out_of_order);
        assert!(!config.migrations.validate_on_migrate);
        assert_eq!(config.migrations.baseline_version, "5");
    }

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
[database]
url = "postgres://user:pass@localhost/mydb"

[migrations]
table = "my_history"
schema = "app"
out_of_order = true
locations = ["sql/migrations", "sql/seeds"]

[placeholders]
env = "production"
app_name = "myapp"
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://user:pass@localhost/mydb")
        );
        assert_eq!(config.migrations.table, "my_history");
        assert_eq!(config.migrations.schema, "app");
        assert!(config.migrations.out_of_order);
        assert_eq!(
            config.migrations.locations,
            vec![PathBuf::from("sql/migrations"), PathBuf::from("sql/seeds")]
        );
        assert_eq!(config.placeholders.get("env").unwrap(), "production");
        assert_eq!(config.placeholders.get("app_name").unwrap(), "myapp");
    }

    #[test]
    fn test_toml_database_settings() {
        let toml_str = r#"
[database]
host = "dbhost"
port = 6543
user = "app"
database = "appdb"
connect_retries = 3
ssl_mode = "require"
connect_timeout = 10
statement_timeout = 60
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(config.database.host.as_deref(), Some("dbhost"));
        assert_eq!(config.database.port, Some(6543));
        assert_eq!(config.database.user.as_deref(), Some("app"));
        assert_eq!(config.database.database.as_deref(), Some("appdb"));
        assert_eq!(config.database.connect_retries, 3);
        assert_eq!(config.database.ssl_mode, SslMode::Require);
        assert_eq!(config.database.connect_timeout_secs, 10);
        assert_eq!(config.database.statement_timeout_secs, 60);
    }

    #[test]
    fn test_toml_hooks_and_flyway_locations() {
        let toml_str = r#"
[migrations]
locations = ["filesystem:db/sql"]

[hooks]
before_migrate = ["hooks/pre.sql"]
after_each_migrate = ["hooks/each.sql"]
"#;

        let toml_config: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(toml_config);

        assert_eq!(config.migrations.locations, vec![PathBuf::from("db/sql")]);
        assert_eq!(
            config.hooks.before_migrate,
            vec![PathBuf::from("hooks/pre.sql")]
        );
        assert!(config.hooks.after_migrate.is_empty());
        assert_eq!(
            config.hooks.after_each_migrate,
            vec![PathBuf::from("hooks/each.sql")]
        );
    }

    #[test]
    fn test_normalize_jdbc_url_with_credentials() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin:secret@myhost:5432/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_user_only() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin@myhost:5432/mydb"
        );
    }

    #[test]
    fn test_normalize_jdbc_url_password_only() {
        let url = "jdbc:postgresql://myhost/mydb?password=secret";
        assert_eq!(normalize_jdbc_url(url), "postgresql://:secret@myhost/mydb");
    }

    #[test]
    fn test_normalize_jdbc_url_postgres_scheme() {
        let url = "jdbc:postgres://myhost/mydb?user=admin";
        assert_eq!(normalize_jdbc_url(url), "postgres://admin@myhost/mydb");
    }

    #[test]
    fn test_normalize_jdbc_url_strips_jdbc_prefix() {
        let url = "jdbc:postgresql://myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(url), "postgresql://myhost:5432/mydb");
    }

    #[test]
    fn test_normalize_jdbc_url_passthrough() {
        let url = "postgresql://user:pass@myhost:5432/mydb";
        assert_eq!(normalize_jdbc_url(url), url);
    }

    #[test]
    fn test_normalize_jdbc_url_preserves_other_params() {
        let url = "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret&sslmode=require";
        assert_eq!(
            normalize_jdbc_url(url),
            "postgresql://admin:secret@myhost:5432/mydb?sslmode=require"
        );
    }

    #[test]
    fn test_normalize_location_filesystem_prefix() {
        assert_eq!(
            normalize_location("filesystem:/flyway/sql"),
            PathBuf::from("/flyway/sql")
        );
    }

    #[test]
    fn test_normalize_location_plain_path() {
        assert_eq!(
            normalize_location("/my/migrations"),
            PathBuf::from("/my/migrations")
        );
    }

    #[test]
    fn test_normalize_location_relative() {
        assert_eq!(
            normalize_location("filesystem:db/migrations"),
            PathBuf::from("db/migrations")
        );
    }
}