Skip to main content

waypoint_core/
config.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4
5use serde::Deserialize;
6
7use crate::error::{Result, WaypointError};
8
/// SSL/TLS connection mode.
///
/// Parsed case-insensitively from `disable`/`disabled`, `prefer`, or
/// `require`/`required` via the `FromStr` impl.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SslMode {
    /// Never use TLS.
    Disable,
    /// Try TLS first, fall back to plaintext. This is the default mode.
    #[default]
    Prefer,
    /// Require TLS — fail if handshake fails.
    Require,
}
20
21impl std::str::FromStr for SslMode {
22    type Err = WaypointError;
23
24    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
25        match s.to_lowercase().as_str() {
26            "disable" | "disabled" => Ok(SslMode::Disable),
27            "prefer" => Ok(SslMode::Prefer),
28            "require" | "required" => Ok(SslMode::Require),
29            _ => Err(WaypointError::ConfigError(format!(
30                "Invalid SSL mode '{}'. Use 'disable', 'prefer', or 'require'.",
31                s
32            ))),
33        }
34    }
35}
36
37/// Top-level configuration for Waypoint.
38#[derive(Debug, Clone, Default)]
39pub struct WaypointConfig {
40    pub database: DatabaseConfig,
41    pub migrations: MigrationSettings,
42    pub hooks: HooksConfig,
43    pub placeholders: HashMap<String, String>,
44}
45
46/// Database connection configuration.
47#[derive(Clone)]
48pub struct DatabaseConfig {
49    pub url: Option<String>,
50    pub host: Option<String>,
51    pub port: Option<u16>,
52    pub user: Option<String>,
53    pub password: Option<String>,
54    pub database: Option<String>,
55    pub connect_retries: u32,
56    pub ssl_mode: SslMode,
57    pub connect_timeout_secs: u32,
58    pub statement_timeout_secs: u32,
59}
60
61impl Default for DatabaseConfig {
62    fn default() -> Self {
63        Self {
64            url: None,
65            host: None,
66            port: None,
67            user: None,
68            password: None,
69            database: None,
70            connect_retries: 0,
71            ssl_mode: SslMode::Prefer,
72            connect_timeout_secs: 30,
73            statement_timeout_secs: 0,
74        }
75    }
76}
77
78impl fmt::Debug for DatabaseConfig {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        f.debug_struct("DatabaseConfig")
81            .field("url", &self.url.as_ref().map(|_| "[REDACTED]"))
82            .field("host", &self.host)
83            .field("port", &self.port)
84            .field("user", &self.user)
85            .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
86            .field("database", &self.database)
87            .field("connect_retries", &self.connect_retries)
88            .field("ssl_mode", &self.ssl_mode)
89            .field("connect_timeout_secs", &self.connect_timeout_secs)
90            .field("statement_timeout_secs", &self.statement_timeout_secs)
91            .finish()
92    }
93}
94
/// Hook configuration for running SQL before/after migrations.
///
/// Each list holds hook file paths taken from the matching key of the TOML
/// `[hooks]` table; all lists are empty by default.
#[derive(Clone, Debug, Default)]
pub struct HooksConfig {
    /// Paths from `[hooks] before_migrate`.
    pub before_migrate: Vec<PathBuf>,
    /// Paths from `[hooks] after_migrate`.
    pub after_migrate: Vec<PathBuf>,
    /// Paths from `[hooks] before_each_migrate`.
    pub before_each_migrate: Vec<PathBuf>,
    /// Paths from `[hooks] after_each_migrate`.
    pub after_each_migrate: Vec<PathBuf>,
}
103
/// Migration behavior settings.
#[derive(Clone, Debug)]
pub struct MigrationSettings {
    /// Directories to scan for migrations (default `db/migrations`); a Flyway
    /// `filesystem:` prefix is stripped when read from TOML or the environment.
    pub locations: Vec<PathBuf>,
    /// Schema history table name (default `waypoint_schema_history`).
    pub table: String,
    /// Schema holding the history table (default `public`).
    pub schema: String,
    /// Allow out-of-order migrations (default `false`).
    pub out_of_order: bool,
    /// Validate before migrating (default `true`).
    pub validate_on_migrate: bool,
    /// Whether the clean operation is enabled (default `false`; TOML only).
    pub clean_enabled: bool,
    /// Baseline version (default `"1"`).
    pub baseline_version: String,
    /// Optional "installed by" value (TOML only).
    /// NOTE(review): presumably recorded in the history table — confirm.
    pub installed_by: Option<String>,
}
116
117impl Default for MigrationSettings {
118    fn default() -> Self {
119        Self {
120            locations: vec![PathBuf::from("db/migrations")],
121            table: "waypoint_schema_history".to_string(),
122            schema: "public".to_string(),
123            out_of_order: false,
124            validate_on_migrate: true,
125            clean_enabled: false,
126            baseline_version: "1".to_string(),
127            installed_by: None,
128        }
129    }
130}
131
132// ── TOML deserialization structs ──
133
134#[derive(Deserialize, Default)]
135struct TomlConfig {
136    database: Option<TomlDatabaseConfig>,
137    migrations: Option<TomlMigrationSettings>,
138    hooks: Option<TomlHooksConfig>,
139    placeholders: Option<HashMap<String, String>>,
140}
141
142#[derive(Deserialize, Default)]
143struct TomlDatabaseConfig {
144    url: Option<String>,
145    host: Option<String>,
146    port: Option<u16>,
147    user: Option<String>,
148    password: Option<String>,
149    database: Option<String>,
150    connect_retries: Option<u32>,
151    ssl_mode: Option<String>,
152    connect_timeout: Option<u32>,
153    statement_timeout: Option<u32>,
154}
155
156#[derive(Deserialize, Default)]
157struct TomlMigrationSettings {
158    locations: Option<Vec<String>>,
159    table: Option<String>,
160    schema: Option<String>,
161    out_of_order: Option<bool>,
162    validate_on_migrate: Option<bool>,
163    clean_enabled: Option<bool>,
164    baseline_version: Option<String>,
165    installed_by: Option<String>,
166}
167
168#[derive(Deserialize, Default)]
169struct TomlHooksConfig {
170    before_migrate: Option<Vec<String>>,
171    after_migrate: Option<Vec<String>>,
172    before_each_migrate: Option<Vec<String>>,
173    after_each_migrate: Option<Vec<String>>,
174}
175
/// CLI overrides that take highest priority.
///
/// Every field is optional: `None` keeps whatever the lower-priority layers
/// (environment, TOML file, defaults) produced.
#[derive(Clone, Debug, Default)]
pub struct CliOverrides {
    /// Overrides `database.url`.
    pub url: Option<String>,
    /// Overrides `migrations.schema`.
    pub schema: Option<String>,
    /// Overrides `migrations.table`.
    pub table: Option<String>,
    /// Overrides `migrations.locations`.
    pub locations: Option<Vec<PathBuf>>,
    /// Overrides `migrations.out_of_order`.
    pub out_of_order: Option<bool>,
    /// Overrides `migrations.validate_on_migrate`.
    pub validate_on_migrate: Option<bool>,
    /// Overrides `migrations.baseline_version`.
    pub baseline_version: Option<String>,
    /// Overrides `database.connect_retries`.
    pub connect_retries: Option<u32>,
    /// Raw SSL mode name, parsed into `SslMode`.
    pub ssl_mode: Option<String>,
    /// Overrides `database.connect_timeout_secs` (seconds).
    pub connect_timeout: Option<u32>,
    /// Overrides `database.statement_timeout_secs` (seconds).
    pub statement_timeout: Option<u32>,
}
191
192impl WaypointConfig {
193    /// Load configuration with the following priority (highest wins):
194    /// 1. CLI arguments
195    /// 2. Environment variables
196    /// 3. TOML config file
197    /// 4. Built-in defaults
198    pub fn load(config_path: Option<&str>, overrides: &CliOverrides) -> Result<Self> {
199        let mut config = WaypointConfig::default();
200
201        // Layer 3: TOML config file
202        let toml_path = config_path.unwrap_or("waypoint.toml");
203        if let Ok(content) = std::fs::read_to_string(toml_path) {
204            // Warn if config file has overly permissive permissions (Unix only)
205            #[cfg(unix)]
206            {
207                use std::os::unix::fs::PermissionsExt;
208                if let Ok(meta) = std::fs::metadata(toml_path) {
209                    let mode = meta.permissions().mode();
210                    if mode & 0o077 != 0 {
211                        tracing::warn!(
212                            path = %toml_path,
213                            mode = format!("{:o}", mode),
214                            "Config file has overly permissive permissions. Consider chmod 600."
215                        );
216                    }
217                }
218            }
219            let toml_config: TomlConfig = toml::from_str(&content).map_err(|e| {
220                WaypointError::ConfigError(format!(
221                    "Failed to parse config file '{}': {}",
222                    toml_path, e
223                ))
224            })?;
225            config.apply_toml(toml_config);
226        } else if config_path.is_some() {
227            // If explicitly specified, error if not found
228            return Err(WaypointError::ConfigError(format!(
229                "Config file '{}' not found",
230                toml_path
231            )));
232        }
233
234        // Layer 2: Environment variables
235        config.apply_env();
236
237        // Layer 1: CLI overrides
238        config.apply_cli(overrides);
239
240        // Validate identifiers
241        crate::db::validate_identifier(&config.migrations.schema)?;
242        crate::db::validate_identifier(&config.migrations.table)?;
243
244        // Cap connect_retries at 20
245        if config.database.connect_retries > 20 {
246            config.database.connect_retries = 20;
247            tracing::warn!("connect_retries capped at 20");
248        }
249
250        Ok(config)
251    }
252
253    fn apply_toml(&mut self, toml: TomlConfig) {
254        if let Some(db) = toml.database {
255            if let Some(v) = db.url {
256                self.database.url = Some(v);
257            }
258            if let Some(v) = db.host {
259                self.database.host = Some(v);
260            }
261            if let Some(v) = db.port {
262                self.database.port = Some(v);
263            }
264            if let Some(v) = db.user {
265                self.database.user = Some(v);
266            }
267            if let Some(v) = db.password {
268                self.database.password = Some(v);
269            }
270            if let Some(v) = db.database {
271                self.database.database = Some(v);
272            }
273            if let Some(v) = db.connect_retries {
274                self.database.connect_retries = v;
275            }
276            if let Some(v) = db.ssl_mode {
277                if let Ok(mode) = v.parse() {
278                    self.database.ssl_mode = mode;
279                }
280            }
281            if let Some(v) = db.connect_timeout {
282                self.database.connect_timeout_secs = v;
283            }
284            if let Some(v) = db.statement_timeout {
285                self.database.statement_timeout_secs = v;
286            }
287        }
288
289        if let Some(m) = toml.migrations {
290            if let Some(v) = m.locations {
291                self.migrations.locations = v.into_iter().map(|s| normalize_location(&s)).collect();
292            }
293            if let Some(v) = m.table {
294                self.migrations.table = v;
295            }
296            if let Some(v) = m.schema {
297                self.migrations.schema = v;
298            }
299            if let Some(v) = m.out_of_order {
300                self.migrations.out_of_order = v;
301            }
302            if let Some(v) = m.validate_on_migrate {
303                self.migrations.validate_on_migrate = v;
304            }
305            if let Some(v) = m.clean_enabled {
306                self.migrations.clean_enabled = v;
307            }
308            if let Some(v) = m.baseline_version {
309                self.migrations.baseline_version = v;
310            }
311            if let Some(v) = m.installed_by {
312                self.migrations.installed_by = Some(v);
313            }
314        }
315
316        if let Some(h) = toml.hooks {
317            if let Some(v) = h.before_migrate {
318                self.hooks.before_migrate = v.into_iter().map(PathBuf::from).collect();
319            }
320            if let Some(v) = h.after_migrate {
321                self.hooks.after_migrate = v.into_iter().map(PathBuf::from).collect();
322            }
323            if let Some(v) = h.before_each_migrate {
324                self.hooks.before_each_migrate = v.into_iter().map(PathBuf::from).collect();
325            }
326            if let Some(v) = h.after_each_migrate {
327                self.hooks.after_each_migrate = v.into_iter().map(PathBuf::from).collect();
328            }
329        }
330
331        if let Some(p) = toml.placeholders {
332            self.placeholders.extend(p);
333        }
334    }
335
336    fn apply_env(&mut self) {
337        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_URL") {
338            self.database.url = Some(v);
339        }
340        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_HOST") {
341            self.database.host = Some(v);
342        }
343        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PORT") {
344            if let Ok(port) = v.parse::<u16>() {
345                self.database.port = Some(port);
346            }
347        }
348        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_USER") {
349            self.database.user = Some(v);
350        }
351        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_PASSWORD") {
352            self.database.password = Some(v);
353        }
354        if let Ok(v) = std::env::var("WAYPOINT_DATABASE_NAME") {
355            self.database.database = Some(v);
356        }
357        if let Ok(v) = std::env::var("WAYPOINT_CONNECT_RETRIES") {
358            if let Ok(n) = v.parse::<u32>() {
359                self.database.connect_retries = n;
360            }
361        }
362        if let Ok(v) = std::env::var("WAYPOINT_SSL_MODE") {
363            if let Ok(mode) = v.parse() {
364                self.database.ssl_mode = mode;
365            }
366        }
367        if let Ok(v) = std::env::var("WAYPOINT_CONNECT_TIMEOUT") {
368            if let Ok(n) = v.parse::<u32>() {
369                self.database.connect_timeout_secs = n;
370            }
371        }
372        if let Ok(v) = std::env::var("WAYPOINT_STATEMENT_TIMEOUT") {
373            if let Ok(n) = v.parse::<u32>() {
374                self.database.statement_timeout_secs = n;
375            }
376        }
377        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_LOCATIONS") {
378            self.migrations.locations =
379                v.split(',').map(|s| normalize_location(s.trim())).collect();
380        }
381        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_TABLE") {
382            self.migrations.table = v;
383        }
384        if let Ok(v) = std::env::var("WAYPOINT_MIGRATIONS_SCHEMA") {
385            self.migrations.schema = v;
386        }
387
388        // Scan for placeholder env vars: WAYPOINT_PLACEHOLDER_{KEY}
389        for (key, value) in std::env::vars() {
390            if let Some(placeholder_key) = key.strip_prefix("WAYPOINT_PLACEHOLDER_") {
391                self.placeholders
392                    .insert(placeholder_key.to_lowercase(), value);
393            }
394        }
395    }
396
397    fn apply_cli(&mut self, overrides: &CliOverrides) {
398        if let Some(ref v) = overrides.url {
399            self.database.url = Some(v.clone());
400        }
401        if let Some(ref v) = overrides.schema {
402            self.migrations.schema = v.clone();
403        }
404        if let Some(ref v) = overrides.table {
405            self.migrations.table = v.clone();
406        }
407        if let Some(ref v) = overrides.locations {
408            self.migrations.locations = v.clone();
409        }
410        if let Some(v) = overrides.out_of_order {
411            self.migrations.out_of_order = v;
412        }
413        if let Some(v) = overrides.validate_on_migrate {
414            self.migrations.validate_on_migrate = v;
415        }
416        if let Some(ref v) = overrides.baseline_version {
417            self.migrations.baseline_version = v.clone();
418        }
419        if let Some(v) = overrides.connect_retries {
420            self.database.connect_retries = v;
421        }
422        if let Some(ref v) = overrides.ssl_mode {
423            // Ignore parse errors here — they'll be caught in validation
424            if let Ok(mode) = v.parse() {
425                self.database.ssl_mode = mode;
426            }
427        }
428        if let Some(v) = overrides.connect_timeout {
429            self.database.connect_timeout_secs = v;
430        }
431        if let Some(v) = overrides.statement_timeout {
432            self.database.statement_timeout_secs = v;
433        }
434    }
435
436    /// Build a connection string from the config.
437    /// Prefers `url` if set; otherwise builds from individual fields.
438    /// Handles JDBC-style URLs by stripping the `jdbc:` prefix and
439    /// extracting `user` and `password` query parameters.
440    pub fn connection_string(&self) -> Result<String> {
441        if let Some(ref url) = self.database.url {
442            return Ok(normalize_jdbc_url(url));
443        }
444
445        let host = self.database.host.as_deref().unwrap_or("localhost");
446        let port = self.database.port.unwrap_or(5432);
447        let user =
448            self.database.user.as_deref().ok_or_else(|| {
449                WaypointError::ConfigError("Database user is required".to_string())
450            })?;
451        let database =
452            self.database.database.as_deref().ok_or_else(|| {
453                WaypointError::ConfigError("Database name is required".to_string())
454            })?;
455
456        let mut url = format!(
457            "host={} port={} user={} dbname={}",
458            host, port, user, database
459        );
460
461        if let Some(ref password) = self.database.password {
462            url.push_str(&format!(" password={}", password));
463        }
464
465        Ok(url)
466    }
467}
468
/// Normalize a JDBC-style URL to a standard PostgreSQL connection string.
///
/// Handles:
///   - `jdbc:postgresql://host:port/db?user=x&password=y`  →  `postgresql://x:y@host:port/db`
///   - `postgresql://...` passed through as-is
///   - `postgres://...` passed through as-is
///
/// Details:
///   - The `jdbc:` prefix is always stripped.
///   - `user`/`password` query parameters (keys matched case-insensitively) are
///     moved into the authority.
///   - Characters that would change how the authority is parsed are
///     percent-encoded: `@ / ? # [ ]`, space, and `:` in the user name. Other
///     characters, including existing `%XX` escapes, are copied unchanged.
///   - Other query parameters keep their order, including flag-style parameters
///     without `=`. Empty segments (`&&`) are dropped.
///   - The URL is returned with only `jdbc:` stripped in two cases: it is not of
///     the form `postgres[ql]://...`, or its authority already has userinfo. This
///     ensures credentials are never silently dropped or duplicated.
fn normalize_jdbc_url(url: &str) -> String {
    /// Percent-encode the characters of a query value that are not allowed (or
    /// would be misparsed) in URL userinfo. `%` is left alone so existing escapes
    /// survive; `:` is encoded only where it would split user from password.
    fn encode_userinfo(value: &str, allow_colon: bool) -> String {
        let mut out = String::with_capacity(value.len());
        for c in value.chars() {
            match c {
                ':' if !allow_colon => out.push_str("%3A"),
                '@' => out.push_str("%40"),
                '/' => out.push_str("%2F"),
                '?' => out.push_str("%3F"),
                '#' => out.push_str("%23"),
                '[' => out.push_str("%5B"),
                ']' => out.push_str("%5D"),
                ' ' => out.push_str("%20"),
                _ => out.push(c),
            }
        }
        out
    }

    // Strip jdbc: prefix
    let url = url.strip_prefix("jdbc:").unwrap_or(url);

    let (base, query) = match url.split_once('?') {
        Some(parts) => parts,
        None => return url.to_string(),
    };

    let mut user = None;
    let mut password = None;
    let mut other_params = Vec::new();

    for param in query.split('&').filter(|p| !p.is_empty()) {
        match param.split_once('=') {
            Some((key, value)) if key.eq_ignore_ascii_case("user") => user = Some(value),
            Some((key, value)) if key.eq_ignore_ascii_case("password") => password = Some(value),
            // Keep everything else, including flag-style params with no `=`.
            _ => other_params.push(param),
        }
    }

    if user.is_none() && password.is_none() {
        // No credentials to move; re-emit the remaining parameters.
        if other_params.is_empty() {
            return base.to_string();
        }
        return format!("{}?{}", base, other_params.join("&"));
    }

    let (scheme, rest) = if let Some(rest) = base.strip_prefix("postgresql://") {
        ("postgresql", rest)
    } else if let Some(rest) = base.strip_prefix("postgres://") {
        ("postgres", rest)
    } else {
        // Unrecognized form (e.g. `postgresql:dbname`): leave credentials in the query.
        return url.to_string();
    };

    let authority = rest.split_once('/').map_or(rest, |(authority, _)| authority);
    if authority.contains('@') {
        // Userinfo is already present; adding a second one would corrupt the URL.
        return url.to_string();
    }

    let mut auth = String::new();
    if let Some(u) = user {
        auth.push_str(&encode_userinfo(u, false));
    }
    if let Some(p) = password {
        auth.push(':');
        auth.push_str(&encode_userinfo(p, true));
    }
    auth.push('@');

    let mut result = format!("{}://{}{}", scheme, auth, rest);
    if !other_params.is_empty() {
        result.push('?');
        result.push_str(&other_params.join("&"));
    }
    result
}
532
/// Strip `filesystem:` prefix from a location path (Flyway compatibility).
///
/// Only one leading `filesystem:` is removed; any other input is converted
/// to a `PathBuf` unchanged.
pub fn normalize_location(location: &str) -> PathBuf {
    match location.strip_prefix("filesystem:") {
        Some(path) => PathBuf::from(path),
        None => PathBuf::from(location),
    }
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `normalize_jdbc_url(input)` produces `expected`.
    fn assert_jdbc(input: &str, expected: &str) {
        assert_eq!(normalize_jdbc_url(input), expected, "input: {input}");
    }

    #[test]
    fn test_default_config() {
        let WaypointConfig { migrations, .. } = WaypointConfig::default();
        assert_eq!(migrations.table, "waypoint_schema_history");
        assert_eq!(migrations.schema, "public");
        assert!(!migrations.out_of_order);
        assert!(migrations.validate_on_migrate);
        assert!(!migrations.clean_enabled);
        assert_eq!(migrations.baseline_version, "1");
        assert_eq!(migrations.locations, vec![PathBuf::from("db/migrations")]);
    }

    #[test]
    fn test_connection_string_from_url() {
        let url = "postgres://user:pass@localhost/db";
        let mut config = WaypointConfig::default();
        config.database.url = Some(url.to_owned());
        assert_eq!(config.connection_string().unwrap(), url);
    }

    #[test]
    fn test_connection_string_from_fields() {
        let mut config = WaypointConfig::default();
        config.database = DatabaseConfig {
            host: Some("myhost".to_owned()),
            port: Some(5433),
            user: Some("myuser".to_owned()),
            database: Some("mydb".to_owned()),
            password: Some("secret".to_owned()),
            ..DatabaseConfig::default()
        };

        let conn = config.connection_string().unwrap();
        for fragment in [
            "host=myhost",
            "port=5433",
            "user=myuser",
            "dbname=mydb",
            "password=secret",
        ] {
            assert!(conn.contains(fragment), "`{fragment}` missing from `{conn}`");
        }
    }

    #[test]
    fn test_connection_string_missing_user() {
        let mut config = WaypointConfig::default();
        config.database.database = Some("mydb".to_owned());
        assert!(config.connection_string().is_err());
    }

    #[test]
    fn test_cli_overrides() {
        let overrides = CliOverrides {
            url: Some("postgres://override@localhost/db".to_owned()),
            schema: Some("custom_schema".to_owned()),
            table: Some("custom_table".to_owned()),
            locations: Some(vec![PathBuf::from("custom/path")]),
            out_of_order: Some(true),
            validate_on_migrate: Some(false),
            baseline_version: Some("5".to_owned()),
            ..CliOverrides::default()
        };

        let mut config = WaypointConfig::default();
        config.apply_cli(&overrides);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://override@localhost/db")
        );
        let migrations = &config.migrations;
        assert_eq!(migrations.schema, "custom_schema");
        assert_eq!(migrations.table, "custom_table");
        assert_eq!(migrations.locations, vec![PathBuf::from("custom/path")]);
        assert!(migrations.out_of_order);
        assert!(!migrations.validate_on_migrate);
        assert_eq!(migrations.baseline_version, "5");
    }

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
[database]
url = "postgres://user:pass@localhost/mydb"

[migrations]
table = "my_history"
schema = "app"
out_of_order = true
locations = ["sql/migrations", "sql/seeds"]

[placeholders]
env = "production"
app_name = "myapp"
"#;

        let parsed: TomlConfig = toml::from_str(toml_str).unwrap();
        let mut config = WaypointConfig::default();
        config.apply_toml(parsed);

        assert_eq!(
            config.database.url.as_deref(),
            Some("postgres://user:pass@localhost/mydb")
        );
        let migrations = &config.migrations;
        assert_eq!(migrations.table, "my_history");
        assert_eq!(migrations.schema, "app");
        assert!(migrations.out_of_order);
        assert_eq!(
            migrations.locations,
            vec![PathBuf::from("sql/migrations"), PathBuf::from("sql/seeds")]
        );
        assert_eq!(
            config.placeholders.get("env").map(String::as_str),
            Some("production")
        );
        assert_eq!(
            config.placeholders.get("app_name").map(String::as_str),
            Some("myapp")
        );
    }

    #[test]
    fn test_normalize_jdbc_url_with_credentials() {
        assert_jdbc(
            "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret",
            "postgresql://admin:secret@myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_user_only() {
        assert_jdbc(
            "jdbc:postgresql://myhost:5432/mydb?user=admin",
            "postgresql://admin@myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_strips_jdbc_prefix() {
        assert_jdbc(
            "jdbc:postgresql://myhost:5432/mydb",
            "postgresql://myhost:5432/mydb",
        );
    }

    #[test]
    fn test_normalize_jdbc_url_passthrough() {
        let url = "postgresql://user:pass@myhost:5432/mydb";
        assert_jdbc(url, url);
    }

    #[test]
    fn test_normalize_jdbc_url_preserves_other_params() {
        assert_jdbc(
            "jdbc:postgresql://myhost:5432/mydb?user=admin&password=secret&sslmode=require",
            "postgresql://admin:secret@myhost:5432/mydb?sslmode=require",
        );
    }

    #[test]
    fn test_normalize_location_filesystem_prefix() {
        let normalized = normalize_location("filesystem:/flyway/sql");
        assert_eq!(normalized, PathBuf::from("/flyway/sql"));
    }

    #[test]
    fn test_normalize_location_plain_path() {
        let normalized = normalize_location("/my/migrations");
        assert_eq!(normalized, PathBuf::from("/my/migrations"));
    }

    #[test]
    fn test_normalize_location_relative() {
        let normalized = normalize_location("filesystem:db/migrations");
        assert_eq!(normalized, PathBuf::from("db/migrations"));
    }
}