1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::error::{Error, Result};
5
/// Value of the `sslmode` connection setting.
///
/// `Config::parse` maps the query values `disable`, `prefer`, `require`,
/// `verify-ca` and `verify-full` onto these variants, in that order.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// `sslmode=disable`.
    Disable,
    /// `sslmode=prefer`; used when no mode is specified.
    #[default]
    Prefer,
    /// `sslmode=require`.
    Require,
    /// `sslmode=verify-ca`.
    VerifyCa,
    /// `sslmode=verify-full`.
    VerifyFull,
}
21
22#[derive(Clone)]
44pub struct Config {
45 pub(crate) hosts: Vec<(String, u16)>,
46 pub(crate) database: String,
47 pub(crate) user: String,
48 pub(crate) password: Option<String>,
49 pub(crate) ssl_mode: SslMode,
50 pub(crate) application_name: Option<String>,
51 pub(crate) connect_timeout: Duration,
52 pub(crate) statement_timeout: Option<Duration>,
53 pub(crate) keepalive: Option<Duration>,
54 pub(crate) keepalive_idle: Option<Duration>,
55 pub(crate) target_session_attrs: TargetSessionAttrs,
56 pub(crate) extra_float_digits: Option<i32>,
57 pub(crate) load_balance_hosts: LoadBalanceHosts,
58 pub(crate) ssl_client_cert: Option<std::path::PathBuf>,
60 pub(crate) ssl_client_key: Option<std::path::PathBuf>,
62 pub(crate) ssl_direct: bool,
64 pub(crate) channel_binding: ChannelBinding,
66 pub(crate) instrumentation: Option<std::sync::Arc<dyn crate::Instrumentation>>,
68}
69
/// Value of the `channel_binding` connection setting.
///
/// `Config::parse` accepts `prefer`, `require` and `disable`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChannelBinding {
    /// `channel_binding=prefer`; used when nothing is specified.
    #[default]
    Prefer,
    /// `channel_binding=require`.
    Require,
    /// `channel_binding=disable`.
    Disable,
}
81
/// Value of the `target_session_attrs` connection setting.
///
/// `Config::parse` accepts `any`, `read-write` and `read-only`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TargetSessionAttrs {
    /// `target_session_attrs=any`; used when nothing is specified.
    #[default]
    Any,
    /// `target_session_attrs=read-write`.
    ReadWrite,
    /// `target_session_attrs=read-only`.
    ReadOnly,
}
93
/// Value of the `load_balance_hosts` connection setting.
///
/// `Config::parse` accepts `disable` and `random`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceHosts {
    /// `load_balance_hosts=disable`; used when nothing is specified.
    #[default]
    Disable,
    /// `load_balance_hosts=random`.
    Random,
}
103
104impl Config {
105 pub fn parse(s: &str) -> Result<Self> {
111 let s = s.trim();
112
113 let without_scheme = s
114 .strip_prefix("postgres://")
115 .or_else(|| s.strip_prefix("postgresql://"))
116 .ok_or_else(|| {
117 Error::Config(
118 "connection string must start with postgres:// or postgresql://".into(),
119 )
120 })?;
121
122 let (userinfo, rest) = match without_scheme.split_once('@') {
123 Some((ui, rest)) => (Some(ui), rest),
124 None => (None, without_scheme),
125 };
126
127 let (user, password) = match userinfo {
128 Some(ui) => match ui.split_once(':') {
129 Some((u, p)) => (percent_decode(u)?, Some(percent_decode(p)?)),
130 None => (percent_decode(ui)?, None),
131 },
132 None => (String::new(), None),
133 };
134
135 let (hostport, db_and_params) = match rest.split_once('/') {
137 Some((hp, rest)) => (hp, Some(rest)),
138 None => (rest, None),
139 };
140
141 let mut hosts: Vec<(String, u16)> = Vec::new();
143 if hostport.is_empty() {
144 } else {
146 for entry in hostport.split(',') {
147 let (h, p) = match entry.rsplit_once(':') {
148 Some((h, p)) => {
149 let port: u16 = p
150 .parse()
151 .map_err(|_| Error::Config(format!("invalid port: {p}")))?;
152 (h.to_string(), port)
153 }
154 None => (entry.to_string(), 5432),
155 };
156 hosts.push((h, p));
157 }
158 }
159
160 let (database, params_str) = match db_and_params {
161 Some(dp) => match dp.split_once('?') {
162 Some((db, params)) => (percent_decode(db)?, Some(params.to_string())),
163 None => (percent_decode(dp)?, None),
164 },
165 None => (String::new(), None),
166 };
167
168 let mut config = ConfigBuilder::new();
169 for (h, p) in &hosts {
170 config = config.host_port(h.clone(), *p);
171 }
172 config = config.database(database).user(user);
173
174 if let Some(pw) = password {
175 config = config.password(pw);
176 }
177
178 if let Some(params) = params_str {
180 for param in params.split('&') {
181 let (key, value) = param
182 .split_once('=')
183 .ok_or_else(|| Error::Config(format!("invalid parameter: {param}")))?;
184 let value = percent_decode(value)?;
185
186 match key {
187 "sslmode" => {
188 config = config.ssl_mode(match value.as_str() {
189 "disable" => SslMode::Disable,
190 "prefer" => SslMode::Prefer,
191 "require" => SslMode::Require,
192 "verify-ca" => SslMode::VerifyCa,
193 "verify-full" => SslMode::VerifyFull,
194 _ => return Err(Error::Config(format!("invalid sslmode: {value}"))),
195 });
196 }
197 "application_name" => {
198 config = config.application_name(value);
199 }
200 "connect_timeout" => {
201 let secs: u64 = value.parse().map_err(|_| {
202 Error::Config(format!("invalid connect_timeout: {value}"))
203 })?;
204 config = config.connect_timeout(Duration::from_secs(secs));
205 }
206 "statement_timeout" => {
207 let secs: u64 = value.parse().map_err(|_| {
208 Error::Config(format!("invalid statement_timeout: {value}"))
209 })?;
210 config = config.statement_timeout(Duration::from_secs(secs));
211 }
212 "target_session_attrs" => {
213 config = config.target_session_attrs(match value.as_str() {
214 "any" => TargetSessionAttrs::Any,
215 "read-write" => TargetSessionAttrs::ReadWrite,
216 "read-only" => TargetSessionAttrs::ReadOnly,
217 _ => {
218 return Err(Error::Config(format!(
219 "invalid target_session_attrs: {value}"
220 )))
221 }
222 });
223 }
224 "sslcert" => {
225 config = config.ssl_client_cert(PathBuf::from(value));
226 }
227 "sslkey" => {
228 config = config.ssl_client_key(PathBuf::from(value));
229 }
230 "ssldirect" | "sslnegotiation" => {
231 let direct = match value.as_str() {
232 "true" | "direct" => true,
233 "false" | "postgres" => false,
234 _ => return Err(Error::Config(format!("invalid {key}: {value}"))),
235 };
236 config = config.ssl_direct(direct);
237 }
238 "channel_binding" => {
239 config = config.channel_binding(match value.as_str() {
240 "prefer" => ChannelBinding::Prefer,
241 "require" => ChannelBinding::Require,
242 "disable" => ChannelBinding::Disable,
243 _ => {
244 return Err(Error::Config(format!(
245 "invalid channel_binding: {value}"
246 )))
247 }
248 });
249 }
250 "load_balance_hosts" => {
251 config = config.load_balance_hosts(match value.as_str() {
252 "disable" => LoadBalanceHosts::Disable,
253 "random" => LoadBalanceHosts::Random,
254 _ => {
255 return Err(Error::Config(format!(
256 "invalid load_balance_hosts: {value}"
257 )))
258 }
259 });
260 }
261 "host" => {
262 config = config.host_port(value, 5432);
264 }
265 _ => {
266 }
268 }
269 }
270 }
271
272 Ok(config.build())
273 }
274
275 pub fn builder() -> ConfigBuilder {
277 ConfigBuilder::new()
278 }
279
280 pub fn host(&self) -> &str {
284 self.hosts.first().map_or("localhost", |(h, _)| h.as_str())
285 }
286
287 pub fn port(&self) -> u16 {
289 self.hosts.first().map_or(5432, |(_, p)| *p)
290 }
291
292 pub fn hosts(&self) -> &[(String, u16)] {
294 &self.hosts
295 }
296
297 pub fn load_balance_hosts(&self) -> LoadBalanceHosts {
299 self.load_balance_hosts
300 }
301
302 pub fn target_session_attrs(&self) -> TargetSessionAttrs {
304 self.target_session_attrs
305 }
306
307 pub fn database(&self) -> &str {
308 &self.database
309 }
310
311 pub fn user(&self) -> &str {
312 &self.user
313 }
314
315 pub fn password(&self) -> Option<&str> {
316 self.password.as_deref()
317 }
318
319 pub fn ssl_mode(&self) -> SslMode {
320 self.ssl_mode
321 }
322
323 pub fn application_name(&self) -> Option<&str> {
324 self.application_name.as_deref()
325 }
326
327 pub fn connect_timeout(&self) -> Duration {
328 self.connect_timeout
329 }
330
331 pub fn statement_timeout(&self) -> Option<Duration> {
332 self.statement_timeout
333 }
334
335 pub fn ssl_client_cert(&self) -> Option<&std::path::Path> {
337 self.ssl_client_cert.as_deref()
338 }
339
340 pub fn ssl_client_key(&self) -> Option<&std::path::Path> {
342 self.ssl_client_key.as_deref()
343 }
344
345 pub fn ssl_direct(&self) -> bool {
347 self.ssl_direct
348 }
349
350 pub fn channel_binding(&self) -> ChannelBinding {
352 self.channel_binding
353 }
354
355 pub fn with_instrumentation(
358 mut self,
359 instr: std::sync::Arc<dyn crate::Instrumentation>,
360 ) -> Self {
361 self.instrumentation = Some(instr);
362 self
363 }
364}
365
366impl std::fmt::Debug for Config {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 f.debug_struct("Config")
369 .field("hosts", &self.hosts)
370 .field("database", &self.database)
371 .field("user", &self.user)
372 .field("password", &self.password.as_ref().map(|_| "..."))
373 .field("ssl_mode", &self.ssl_mode)
374 .field("application_name", &self.application_name)
375 .field("connect_timeout", &self.connect_timeout)
376 .field("statement_timeout", &self.statement_timeout)
377 .field("_keepalive", &self.keepalive)
378 .field("_keepalive_idle", &self.keepalive_idle)
379 .field("_extra_float_digits", &self.extra_float_digits)
380 .field("target_session_attrs", &self.target_session_attrs)
381 .field("load_balance_hosts", &self.load_balance_hosts)
382 .field("ssl_client_cert", &self.ssl_client_cert)
383 .field("ssl_client_key", &self.ssl_client_key)
384 .field("ssl_direct", &self.ssl_direct)
385 .field("channel_binding", &self.channel_binding)
386 .field(
387 "instrumentation",
388 &self.instrumentation.as_ref().map(|_| "..."),
389 )
390 .finish()
391 }
392}
393
394#[derive(Debug, Clone)]
396pub struct ConfigBuilder {
397 hosts: Vec<(String, u16)>,
398 default_port: u16,
399 database: String,
400 user: String,
401 password: Option<String>,
402 ssl_mode: SslMode,
403 application_name: Option<String>,
404 connect_timeout: Duration,
405 statement_timeout: Option<Duration>,
406 keepalive: Option<Duration>,
407 keepalive_idle: Option<Duration>,
408 target_session_attrs: TargetSessionAttrs,
409 extra_float_digits: Option<i32>,
410 load_balance_hosts: LoadBalanceHosts,
411 ssl_client_cert: Option<PathBuf>,
412 ssl_client_key: Option<PathBuf>,
413 ssl_direct: bool,
414 channel_binding: ChannelBinding,
415}
416
417impl ConfigBuilder {
418 fn new() -> Self {
419 Self {
420 hosts: Vec::new(),
421 default_port: 5432,
422 database: String::new(),
423 user: String::new(),
424 password: None,
425 ssl_mode: SslMode::default(),
426 application_name: None,
427 connect_timeout: Duration::from_secs(10),
428 statement_timeout: None,
429 keepalive: Some(Duration::from_secs(60)),
430 keepalive_idle: None,
431 target_session_attrs: TargetSessionAttrs::default(),
432 extra_float_digits: Some(3),
433 load_balance_hosts: LoadBalanceHosts::default(),
434 ssl_client_cert: None,
435 ssl_client_key: None,
436 ssl_direct: false,
437 channel_binding: ChannelBinding::default(),
438 }
439 }
440
441 pub fn host(mut self, host: impl Into<String>) -> Self {
443 self.hosts.push((host.into(), self.default_port));
444 self
445 }
446
447 pub fn host_port(mut self, host: impl Into<String>, port: u16) -> Self {
449 self.hosts.push((host.into(), port));
450 self
451 }
452
453 pub fn port(mut self, port: u16) -> Self {
456 let old_default = self.default_port;
457 self.default_port = port;
458 for (_, p) in &mut self.hosts {
459 if *p == old_default {
460 *p = port;
461 }
462 }
463 self
464 }
465
466 pub fn load_balance_hosts(mut self, strategy: LoadBalanceHosts) -> Self {
467 self.load_balance_hosts = strategy;
468 self
469 }
470
471 pub fn database(mut self, database: impl Into<String>) -> Self {
472 self.database = database.into();
473 self
474 }
475
476 pub fn user(mut self, user: impl Into<String>) -> Self {
477 self.user = user.into();
478 self
479 }
480
481 pub fn password(mut self, password: impl Into<String>) -> Self {
482 self.password = Some(password.into());
483 self
484 }
485
486 pub fn ssl_mode(mut self, ssl_mode: SslMode) -> Self {
487 self.ssl_mode = ssl_mode;
488 self
489 }
490
491 pub fn application_name(mut self, name: impl Into<String>) -> Self {
492 self.application_name = Some(name.into());
493 self
494 }
495
496 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
497 self.connect_timeout = timeout;
498 self
499 }
500
501 pub fn statement_timeout(mut self, timeout: Duration) -> Self {
502 self.statement_timeout = Some(timeout);
503 self
504 }
505
506 pub fn keepalive(mut self, interval: Duration) -> Self {
507 self.keepalive = Some(interval);
508 self
509 }
510
511 pub fn target_session_attrs(mut self, attrs: TargetSessionAttrs) -> Self {
512 self.target_session_attrs = attrs;
513 self
514 }
515
516 pub fn ssl_client_cert(mut self, path: impl Into<PathBuf>) -> Self {
518 self.ssl_client_cert = Some(path.into());
519 self
520 }
521
522 pub fn ssl_client_key(mut self, path: impl Into<PathBuf>) -> Self {
524 self.ssl_client_key = Some(path.into());
525 self
526 }
527
528 pub fn ssl_direct(mut self, direct: bool) -> Self {
530 self.ssl_direct = direct;
531 self
532 }
533
534 pub fn channel_binding(mut self, binding: ChannelBinding) -> Self {
536 self.channel_binding = binding;
537 self
538 }
539
540 pub fn build(self) -> Config {
542 let hosts = if self.hosts.is_empty() {
543 vec![("localhost".to_string(), self.default_port)]
544 } else {
545 self.hosts
546 };
547 Config {
548 hosts,
549 database: self.database,
550 user: self.user,
551 password: self.password,
552 ssl_mode: self.ssl_mode,
553 application_name: self.application_name,
554 connect_timeout: self.connect_timeout,
555 statement_timeout: self.statement_timeout,
556 keepalive: self.keepalive,
557 keepalive_idle: self.keepalive_idle,
558 target_session_attrs: self.target_session_attrs,
559 extra_float_digits: self.extra_float_digits,
560 load_balance_hosts: self.load_balance_hosts,
561 ssl_client_cert: self.ssl_client_cert,
562 ssl_client_key: self.ssl_client_key,
563 ssl_direct: self.ssl_direct,
564 channel_binding: self.channel_binding,
565 instrumentation: None,
566 }
567 }
568}
569
570fn percent_decode(s: &str) -> Result<String> {
572 let mut result = String::with_capacity(s.len());
573 let mut chars = s.as_bytes().iter();
574
575 while let Some(&b) = chars.next() {
576 if b == b'%' {
577 let hi = chars
578 .next()
579 .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
580 let lo = chars
581 .next()
582 .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
583 let byte = hex_digit(*hi)? << 4 | hex_digit(*lo)?;
584 result.push(byte as char);
585 } else {
586 result.push(b as char);
587 }
588 }
589
590 Ok(result)
591}
592
593fn hex_digit(b: u8) -> Result<u8> {
594 match b {
595 b'0'..=b'9' => Ok(b - b'0'),
596 b'a'..=b'f' => Ok(b - b'a' + 10),
597 b'A'..=b'F' => Ok(b - b'A' + 10),
598 _ => Err(Error::Config(format!("invalid hex digit: {}", b as char))),
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s`, panicking with the input on failure. Matching on the
    /// result instead of calling `unwrap` avoids requiring `Error: Debug`.
    fn parse_ok(s: &str) -> Config {
        match Config::parse(s) {
            Ok(cfg) => cfg,
            Err(_) => panic!("failed to parse {s}"),
        }
    }

    #[test]
    fn config_builder_build_populates_all_fields() {
        let cfg = ConfigBuilder::new()
            .host("localhost".to_string())
            .port(5432)
            .database("test".to_string())
            .user("postgres".to_string())
            .password("secret".to_string())
            .application_name("test_app".to_string())
            .connect_timeout(std::time::Duration::from_secs(10))
            .channel_binding(ChannelBinding::Prefer)
            .build();
        assert_eq!(cfg.user, "postgres");
        assert_eq!(cfg.database, "test");
        assert_eq!(cfg.application_name.as_deref(), Some("test_app"));
        assert_eq!(cfg.channel_binding, ChannelBinding::Prefer);
        assert_eq!(cfg.keepalive, Some(std::time::Duration::from_secs(60)));
        assert!(cfg.keepalive_idle.is_none());
        assert_eq!(cfg.extra_float_digits, Some(3));
        assert!(cfg.instrumentation.is_none());
    }

    #[test]
    fn channel_binding_accessor() {
        let cfg = ConfigBuilder::new()
            .channel_binding(ChannelBinding::Require)
            .build();
        assert_eq!(cfg.channel_binding(), ChannelBinding::Require);
    }

    #[test]
    fn with_instrumentation_sets_field() {
        struct NoOp;
        impl crate::Instrumentation for NoOp {
            fn on_event(&self, _: &crate::Event<'_>) {}
        }
        let cfg = ConfigBuilder::new()
            .build()
            .with_instrumentation(std::sync::Arc::new(NoOp));
        assert!(cfg.instrumentation.is_some());
    }

    #[test]
    fn debug_redacts_password() {
        let cfg = ConfigBuilder::new()
            .password("super_secret".to_string())
            .build();
        let s = format!("{cfg:?}");
        assert!(!s.contains("super_secret"));
        assert!(s.contains("password"));
    }

    #[test]
    fn parse_full_uri() {
        let cfg = parse_ok(
            "postgres://alice:s3cret@db.example.com:6543/app?sslmode=require&application_name=svc",
        );
        assert_eq!(cfg.user(), "alice");
        assert_eq!(cfg.password(), Some("s3cret"));
        assert_eq!(cfg.hosts().to_vec(), vec![("db.example.com".to_string(), 6543)]);
        assert_eq!(cfg.database(), "app");
        assert_eq!(cfg.ssl_mode(), SslMode::Require);
        assert_eq!(cfg.application_name(), Some("svc"));
    }

    #[test]
    fn parse_multiple_hosts_and_postgresql_scheme() {
        let cfg = parse_ok("postgresql://h1,h2:5433/db");
        assert_eq!(
            cfg.hosts().to_vec(),
            vec![("h1".to_string(), 5432), ("h2".to_string(), 5433)]
        );
        assert_eq!(cfg.user(), "");
        assert_eq!(cfg.password(), None);
        assert_eq!(cfg.database(), "db");
    }

    #[test]
    fn parse_without_host_falls_back_to_localhost() {
        let cfg = parse_ok("  postgres:///db  ");
        assert_eq!(cfg.host(), "localhost");
        assert_eq!(cfg.port(), 5432);
        assert_eq!(cfg.database(), "db");
    }

    #[test]
    fn parse_percent_decodes_userinfo() {
        let cfg = parse_ok("postgres://us%40er:p%3Ass@h/db");
        assert_eq!(cfg.user(), "us@er");
        assert_eq!(cfg.password(), Some("p:ss"));
        assert_eq!(cfg.host(), "h");
    }

    #[test]
    fn parse_timeouts_and_enum_params() {
        let cfg = parse_ok(
            "postgres://h/db?connect_timeout=5&statement_timeout=30\
             &target_session_attrs=read-only&channel_binding=require\
             &load_balance_hosts=random&sslnegotiation=direct&unknown_key=x",
        );
        assert_eq!(cfg.connect_timeout(), std::time::Duration::from_secs(5));
        assert_eq!(cfg.statement_timeout(), Some(std::time::Duration::from_secs(30)));
        assert_eq!(cfg.target_session_attrs(), TargetSessionAttrs::ReadOnly);
        assert_eq!(cfg.channel_binding(), ChannelBinding::Require);
        assert_eq!(cfg.load_balance_hosts(), LoadBalanceHosts::Random);
        assert!(cfg.ssl_direct());
    }

    #[test]
    fn parse_rejects_invalid_input() {
        assert!(Config::parse("mysql://h/db").is_err());
        assert!(Config::parse("postgres://h:notaport/db").is_err());
        assert!(Config::parse("postgres://h:99999/db").is_err());
        assert!(Config::parse("postgres://h/db?sslmode").is_err());
        assert!(Config::parse("postgres://h/db?sslmode=bogus").is_err());
        assert!(Config::parse("postgres://h/db?connect_timeout=soon").is_err());
    }

    #[test]
    fn percent_decode_basics() {
        assert!(matches!(percent_decode("a%20b").as_deref(), Ok("a b")));
        assert!(matches!(percent_decode("plain").as_deref(), Ok("plain")));
        assert!(percent_decode("%4").is_err());
        assert!(percent_decode("%zz").is_err());
    }
}