Skip to main content

sentinel_driver/
config.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::error::{Error, Result};
5
/// Transport-encryption policy used when opening a connection.
///
/// Variants are listed from least to most strict.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Never negotiate TLS; all traffic is plaintext.
    Disable,
    /// Attempt TLS first, falling back to plaintext when the server declines.
    /// This is the default policy.
    #[default]
    Prefer,
    /// Insist on TLS and fail when the server does not offer it.
    Require,
    /// Insist on TLS and validate the server's certificate.
    VerifyCa,
    /// Insist on TLS, validate the certificate, and check that it matches the hostname.
    VerifyFull,
}
21
22/// Connection configuration for sentinel-driver.
23///
24/// # Connection String
25///
26/// ```text
27/// postgres://user:password@host:port/database?sslmode=prefer&application_name=myapp
28/// ```
29///
30/// # Builder
31///
32/// ```rust,no_run
33/// use sentinel_driver::Config;
34///
35/// let config = Config::builder()
36///     .host("localhost")
37///     .port(5432)
38///     .database("mydb")
39///     .user("postgres")
40///     .password("secret")
41///     .build();
42/// ```
43#[derive(Debug, Clone)]
44pub struct Config {
45    pub(crate) hosts: Vec<(String, u16)>,
46    pub(crate) database: String,
47    pub(crate) user: String,
48    pub(crate) password: Option<String>,
49    pub(crate) ssl_mode: SslMode,
50    pub(crate) application_name: Option<String>,
51    pub(crate) connect_timeout: Duration,
52    pub(crate) statement_timeout: Option<Duration>,
53    pub(crate) _keepalive: Option<Duration>,
54    pub(crate) _keepalive_idle: Option<Duration>,
55    pub(crate) target_session_attrs: TargetSessionAttrs,
56    pub(crate) _extra_float_digits: Option<i32>,
57    pub(crate) load_balance_hosts: LoadBalanceHosts,
58    /// Path to client certificate file for certificate authentication.
59    pub(crate) ssl_client_cert: Option<std::path::PathBuf>,
60    /// Path to client private key file for certificate authentication.
61    pub(crate) ssl_client_key: Option<std::path::PathBuf>,
62    /// Use direct TLS connection (PG 17+) — skip SSLRequest negotiation.
63    pub(crate) ssl_direct: bool,
64    /// Enable SCRAM-SHA-256 channel binding (SCRAM-PLUS) when TLS is active.
65    pub(crate) channel_binding: ChannelBinding,
66}
67
/// How SCRAM authentication should treat TLS channel binding.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChannelBinding {
    /// Bind to the TLS channel when possible, otherwise continue without it.
    /// This is the default.
    #[default]
    Prefer,
    /// Insist on channel binding and fail when the server cannot provide it.
    Require,
    /// Never use channel binding.
    Disable,
}
79
/// Which kind of server session is acceptable once connected.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TargetSessionAttrs {
    /// Accept whatever server answers. This is the default.
    #[default]
    Any,
    /// Accept only a server that allows writes (a primary).
    ReadWrite,
    /// Accept only a server in read-only mode (a replica).
    ReadOnly,
}
91
/// Order in which the configured hosts are attempted.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceHosts {
    /// Attempt hosts in the order they were given. This is the default.
    #[default]
    Disable,
    /// Shuffle the host list before attempting connections.
    Random,
}
101
102impl Config {
103    /// Parse a PostgreSQL connection string.
104    ///
105    /// Supported formats:
106    /// - `postgres://user:password@host:port/database?param=value`
107    /// - `postgresql://user:password@host:port/database?param=value`
108    pub fn parse(s: &str) -> Result<Self> {
109        let s = s.trim();
110
111        let without_scheme = s
112            .strip_prefix("postgres://")
113            .or_else(|| s.strip_prefix("postgresql://"))
114            .ok_or_else(|| {
115                Error::Config(
116                    "connection string must start with postgres:// or postgresql://".into(),
117                )
118            })?;
119
120        let (userinfo, rest) = match without_scheme.split_once('@') {
121            Some((ui, rest)) => (Some(ui), rest),
122            None => (None, without_scheme),
123        };
124
125        let (user, password) = match userinfo {
126            Some(ui) => match ui.split_once(':') {
127                Some((u, p)) => (percent_decode(u)?, Some(percent_decode(p)?)),
128                None => (percent_decode(ui)?, None),
129            },
130            None => (String::new(), None),
131        };
132
133        // Split host:port from database?params
134        let (hostport, db_and_params) = match rest.split_once('/') {
135            Some((hp, rest)) => (hp, Some(rest)),
136            None => (rest, None),
137        };
138
139        // Parse comma-separated host:port pairs
140        let mut hosts: Vec<(String, u16)> = Vec::new();
141        if hostport.is_empty() {
142            // Empty host — will be set via ?host= parameter (Unix socket)
143        } else {
144            for entry in hostport.split(',') {
145                let (h, p) = match entry.rsplit_once(':') {
146                    Some((h, p)) => {
147                        let port: u16 = p
148                            .parse()
149                            .map_err(|_| Error::Config(format!("invalid port: {p}")))?;
150                        (h.to_string(), port)
151                    }
152                    None => (entry.to_string(), 5432),
153                };
154                hosts.push((h, p));
155            }
156        }
157
158        let (database, params_str) = match db_and_params {
159            Some(dp) => match dp.split_once('?') {
160                Some((db, params)) => (percent_decode(db)?, Some(params.to_string())),
161                None => (percent_decode(dp)?, None),
162            },
163            None => (String::new(), None),
164        };
165
166        let mut config = ConfigBuilder::new();
167        for (h, p) in &hosts {
168            config = config.host_port(h.clone(), *p);
169        }
170        config = config.database(database).user(user);
171
172        if let Some(pw) = password {
173            config = config.password(pw);
174        }
175
176        // Parse query parameters
177        if let Some(params) = params_str {
178            for param in params.split('&') {
179                let (key, value) = param
180                    .split_once('=')
181                    .ok_or_else(|| Error::Config(format!("invalid parameter: {param}")))?;
182                let value = percent_decode(value)?;
183
184                match key {
185                    "sslmode" => {
186                        config = config.ssl_mode(match value.as_str() {
187                            "disable" => SslMode::Disable,
188                            "prefer" => SslMode::Prefer,
189                            "require" => SslMode::Require,
190                            "verify-ca" => SslMode::VerifyCa,
191                            "verify-full" => SslMode::VerifyFull,
192                            _ => return Err(Error::Config(format!("invalid sslmode: {value}"))),
193                        });
194                    }
195                    "application_name" => {
196                        config = config.application_name(value);
197                    }
198                    "connect_timeout" => {
199                        let secs: u64 = value.parse().map_err(|_| {
200                            Error::Config(format!("invalid connect_timeout: {value}"))
201                        })?;
202                        config = config.connect_timeout(Duration::from_secs(secs));
203                    }
204                    "statement_timeout" => {
205                        let secs: u64 = value.parse().map_err(|_| {
206                            Error::Config(format!("invalid statement_timeout: {value}"))
207                        })?;
208                        config = config.statement_timeout(Duration::from_secs(secs));
209                    }
210                    "target_session_attrs" => {
211                        config = config.target_session_attrs(match value.as_str() {
212                            "any" => TargetSessionAttrs::Any,
213                            "read-write" => TargetSessionAttrs::ReadWrite,
214                            "read-only" => TargetSessionAttrs::ReadOnly,
215                            _ => {
216                                return Err(Error::Config(format!(
217                                    "invalid target_session_attrs: {value}"
218                                )))
219                            }
220                        });
221                    }
222                    "sslcert" => {
223                        config = config.ssl_client_cert(PathBuf::from(value));
224                    }
225                    "sslkey" => {
226                        config = config.ssl_client_key(PathBuf::from(value));
227                    }
228                    "ssldirect" | "sslnegotiation" => {
229                        let direct = match value.as_str() {
230                            "true" | "direct" => true,
231                            "false" | "postgres" => false,
232                            _ => return Err(Error::Config(format!("invalid {key}: {value}"))),
233                        };
234                        config = config.ssl_direct(direct);
235                    }
236                    "channel_binding" => {
237                        config = config.channel_binding(match value.as_str() {
238                            "prefer" => ChannelBinding::Prefer,
239                            "require" => ChannelBinding::Require,
240                            "disable" => ChannelBinding::Disable,
241                            _ => {
242                                return Err(Error::Config(format!(
243                                    "invalid channel_binding: {value}"
244                                )))
245                            }
246                        });
247                    }
248                    "load_balance_hosts" => {
249                        config = config.load_balance_hosts(match value.as_str() {
250                            "disable" => LoadBalanceHosts::Disable,
251                            "random" => LoadBalanceHosts::Random,
252                            _ => {
253                                return Err(Error::Config(format!(
254                                    "invalid load_balance_hosts: {value}"
255                                )))
256                            }
257                        });
258                    }
259                    "host" => {
260                        // Support ?host=/var/run/postgresql for Unix sockets
261                        config = config.host_port(value, 5432);
262                    }
263                    _ => {
264                        // Ignore unknown parameters for forward compatibility
265                    }
266                }
267            }
268        }
269
270        Ok(config.build())
271    }
272
273    /// Create a new builder for `Config`.
274    pub fn builder() -> ConfigBuilder {
275        ConfigBuilder::new()
276    }
277
278    // Accessor methods
279
280    /// Returns the first host (for backward compatibility and single-host use).
281    pub fn host(&self) -> &str {
282        self.hosts.first().map_or("localhost", |(h, _)| h.as_str())
283    }
284
285    /// Returns the first port (for backward compatibility and single-host use).
286    pub fn port(&self) -> u16 {
287        self.hosts.first().map_or(5432, |(_, p)| *p)
288    }
289
290    /// Returns all configured host/port pairs.
291    pub fn hosts(&self) -> &[(String, u16)] {
292        &self.hosts
293    }
294
295    /// Load balancing strategy for multi-host connections.
296    pub fn load_balance_hosts(&self) -> LoadBalanceHosts {
297        self.load_balance_hosts
298    }
299
300    /// Target session attributes for connection routing.
301    pub fn target_session_attrs(&self) -> TargetSessionAttrs {
302        self.target_session_attrs
303    }
304
305    pub fn database(&self) -> &str {
306        &self.database
307    }
308
309    pub fn user(&self) -> &str {
310        &self.user
311    }
312
313    pub fn password(&self) -> Option<&str> {
314        self.password.as_deref()
315    }
316
317    pub fn ssl_mode(&self) -> SslMode {
318        self.ssl_mode
319    }
320
321    pub fn application_name(&self) -> Option<&str> {
322        self.application_name.as_deref()
323    }
324
325    pub fn connect_timeout(&self) -> Duration {
326        self.connect_timeout
327    }
328
329    pub fn statement_timeout(&self) -> Option<Duration> {
330        self.statement_timeout
331    }
332
333    /// Path to client certificate for certificate authentication.
334    pub fn ssl_client_cert(&self) -> Option<&std::path::Path> {
335        self.ssl_client_cert.as_deref()
336    }
337
338    /// Path to client private key for certificate authentication.
339    pub fn ssl_client_key(&self) -> Option<&std::path::Path> {
340        self.ssl_client_key.as_deref()
341    }
342
343    /// Whether direct TLS (PG 17+) is enabled.
344    pub fn ssl_direct(&self) -> bool {
345        self.ssl_direct
346    }
347
348    /// Channel binding preference for SCRAM authentication.
349    pub fn channel_binding(&self) -> ChannelBinding {
350        self.channel_binding
351    }
352}
353
354/// Builder for [`Config`].
355#[derive(Debug, Clone)]
356pub struct ConfigBuilder {
357    hosts: Vec<(String, u16)>,
358    default_port: u16,
359    database: String,
360    user: String,
361    password: Option<String>,
362    ssl_mode: SslMode,
363    application_name: Option<String>,
364    connect_timeout: Duration,
365    statement_timeout: Option<Duration>,
366    keepalive: Option<Duration>,
367    keepalive_idle: Option<Duration>,
368    target_session_attrs: TargetSessionAttrs,
369    extra_float_digits: Option<i32>,
370    load_balance_hosts: LoadBalanceHosts,
371    ssl_client_cert: Option<PathBuf>,
372    ssl_client_key: Option<PathBuf>,
373    ssl_direct: bool,
374    channel_binding: ChannelBinding,
375}
376
377impl ConfigBuilder {
378    fn new() -> Self {
379        Self {
380            hosts: Vec::new(),
381            default_port: 5432,
382            database: String::new(),
383            user: String::new(),
384            password: None,
385            ssl_mode: SslMode::default(),
386            application_name: None,
387            connect_timeout: Duration::from_secs(10),
388            statement_timeout: None,
389            keepalive: Some(Duration::from_secs(60)),
390            keepalive_idle: None,
391            target_session_attrs: TargetSessionAttrs::default(),
392            extra_float_digits: Some(3),
393            load_balance_hosts: LoadBalanceHosts::default(),
394            ssl_client_cert: None,
395            ssl_client_key: None,
396            ssl_direct: false,
397            channel_binding: ChannelBinding::default(),
398        }
399    }
400
401    /// Append a host with the current default port.
402    pub fn host(mut self, host: impl Into<String>) -> Self {
403        self.hosts.push((host.into(), self.default_port));
404        self
405    }
406
407    /// Append a host with a specific port.
408    pub fn host_port(mut self, host: impl Into<String>, port: u16) -> Self {
409        self.hosts.push((host.into(), port));
410        self
411    }
412
413    /// Set the default port for subsequent `.host()` calls and update
414    /// any hosts that still have the old default port.
415    pub fn port(mut self, port: u16) -> Self {
416        let old_default = self.default_port;
417        self.default_port = port;
418        for (_, p) in &mut self.hosts {
419            if *p == old_default {
420                *p = port;
421            }
422        }
423        self
424    }
425
426    pub fn load_balance_hosts(mut self, strategy: LoadBalanceHosts) -> Self {
427        self.load_balance_hosts = strategy;
428        self
429    }
430
431    pub fn database(mut self, database: impl Into<String>) -> Self {
432        self.database = database.into();
433        self
434    }
435
436    pub fn user(mut self, user: impl Into<String>) -> Self {
437        self.user = user.into();
438        self
439    }
440
441    pub fn password(mut self, password: impl Into<String>) -> Self {
442        self.password = Some(password.into());
443        self
444    }
445
446    pub fn ssl_mode(mut self, ssl_mode: SslMode) -> Self {
447        self.ssl_mode = ssl_mode;
448        self
449    }
450
451    pub fn application_name(mut self, name: impl Into<String>) -> Self {
452        self.application_name = Some(name.into());
453        self
454    }
455
456    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
457        self.connect_timeout = timeout;
458        self
459    }
460
461    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
462        self.statement_timeout = Some(timeout);
463        self
464    }
465
466    pub fn keepalive(mut self, interval: Duration) -> Self {
467        self.keepalive = Some(interval);
468        self
469    }
470
471    pub fn target_session_attrs(mut self, attrs: TargetSessionAttrs) -> Self {
472        self.target_session_attrs = attrs;
473        self
474    }
475
476    /// Set the path to the client certificate file for certificate authentication.
477    pub fn ssl_client_cert(mut self, path: impl Into<PathBuf>) -> Self {
478        self.ssl_client_cert = Some(path.into());
479        self
480    }
481
482    /// Set the path to the client private key file for certificate authentication.
483    pub fn ssl_client_key(mut self, path: impl Into<PathBuf>) -> Self {
484        self.ssl_client_key = Some(path.into());
485        self
486    }
487
488    /// Enable direct TLS connection (PG 17+), skipping SSLRequest negotiation.
489    pub fn ssl_direct(mut self, direct: bool) -> Self {
490        self.ssl_direct = direct;
491        self
492    }
493
494    /// Set the channel binding preference for SCRAM authentication.
495    pub fn channel_binding(mut self, binding: ChannelBinding) -> Self {
496        self.channel_binding = binding;
497        self
498    }
499
500    /// Build the final `Config`.
501    pub fn build(self) -> Config {
502        let hosts = if self.hosts.is_empty() {
503            vec![("localhost".to_string(), self.default_port)]
504        } else {
505            self.hosts
506        };
507        Config {
508            hosts,
509            database: self.database,
510            user: self.user,
511            password: self.password,
512            ssl_mode: self.ssl_mode,
513            application_name: self.application_name,
514            connect_timeout: self.connect_timeout,
515            statement_timeout: self.statement_timeout,
516            _keepalive: self.keepalive,
517            _keepalive_idle: self.keepalive_idle,
518            target_session_attrs: self.target_session_attrs,
519            _extra_float_digits: self.extra_float_digits,
520            load_balance_hosts: self.load_balance_hosts,
521            ssl_client_cert: self.ssl_client_cert,
522            ssl_client_key: self.ssl_client_key,
523            ssl_direct: self.ssl_direct,
524            channel_binding: self.channel_binding,
525        }
526    }
527}
528
529/// Percent-decode a URL component.
530fn percent_decode(s: &str) -> Result<String> {
531    let mut result = String::with_capacity(s.len());
532    let mut chars = s.as_bytes().iter();
533
534    while let Some(&b) = chars.next() {
535        if b == b'%' {
536            let hi = chars
537                .next()
538                .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
539            let lo = chars
540                .next()
541                .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
542            let byte = hex_digit(*hi)? << 4 | hex_digit(*lo)?;
543            result.push(byte as char);
544        } else {
545            result.push(b as char);
546        }
547    }
548
549    Ok(result)
550}
551
552fn hex_digit(b: u8) -> Result<u8> {
553    match b {
554        b'0'..=b'9' => Ok(b - b'0'),
555        b'a'..=b'f' => Ok(b - b'a' + 10),
556        b'A'..=b'F' => Ok(b - b'A' + 10),
557        _ => Err(Error::Config(format!("invalid hex digit: {}", b as char))),
558    }
559}