1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::error::{Error, Result};
5
/// TLS negotiation policy for a connection.
///
/// Each variant corresponds to one value of the `sslmode` connection-URI
/// parameter accepted by [`Config::parse`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// `sslmode=disable`.
    Disable,
    /// `sslmode=prefer`; this is the default when nothing is specified.
    #[default]
    Prefer,
    /// `sslmode=require`.
    Require,
    /// `sslmode=verify-ca`.
    VerifyCa,
    /// `sslmode=verify-full`.
    VerifyFull,
}
21
22#[derive(Debug, Clone)]
44pub struct Config {
45 pub(crate) hosts: Vec<(String, u16)>,
46 pub(crate) database: String,
47 pub(crate) user: String,
48 pub(crate) password: Option<String>,
49 pub(crate) ssl_mode: SslMode,
50 pub(crate) application_name: Option<String>,
51 pub(crate) connect_timeout: Duration,
52 pub(crate) statement_timeout: Option<Duration>,
53 pub(crate) _keepalive: Option<Duration>,
54 pub(crate) _keepalive_idle: Option<Duration>,
55 pub(crate) target_session_attrs: TargetSessionAttrs,
56 pub(crate) _extra_float_digits: Option<i32>,
57 pub(crate) load_balance_hosts: LoadBalanceHosts,
58 pub(crate) ssl_client_cert: Option<std::path::PathBuf>,
60 pub(crate) ssl_client_key: Option<std::path::PathBuf>,
62 pub(crate) ssl_direct: bool,
64 pub(crate) channel_binding: ChannelBinding,
66}
67
/// Channel-binding policy.
///
/// Each variant corresponds to one value of the `channel_binding`
/// connection-URI parameter accepted by [`Config::parse`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChannelBinding {
    /// `channel_binding=prefer`; this is the default.
    #[default]
    Prefer,
    /// `channel_binding=require`.
    Require,
    /// `channel_binding=disable`.
    Disable,
}
79
/// Session kind that a candidate host must offer.
///
/// Each variant corresponds to one value of the `target_session_attrs`
/// connection-URI parameter accepted by [`Config::parse`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TargetSessionAttrs {
    /// `target_session_attrs=any`; this is the default.
    #[default]
    Any,
    /// `target_session_attrs=read-write`.
    ReadWrite,
    /// `target_session_attrs=read-only`.
    ReadOnly,
}
91
/// Host ordering strategy when several hosts are configured.
///
/// Each variant corresponds to one value of the `load_balance_hosts`
/// connection-URI parameter accepted by [`Config::parse`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceHosts {
    /// `load_balance_hosts=disable`; this is the default.
    #[default]
    Disable,
    /// `load_balance_hosts=random`.
    Random,
}
101
102impl Config {
103 pub fn parse(s: &str) -> Result<Self> {
109 let s = s.trim();
110
111 let without_scheme = s
112 .strip_prefix("postgres://")
113 .or_else(|| s.strip_prefix("postgresql://"))
114 .ok_or_else(|| {
115 Error::Config(
116 "connection string must start with postgres:// or postgresql://".into(),
117 )
118 })?;
119
120 let (userinfo, rest) = match without_scheme.split_once('@') {
121 Some((ui, rest)) => (Some(ui), rest),
122 None => (None, without_scheme),
123 };
124
125 let (user, password) = match userinfo {
126 Some(ui) => match ui.split_once(':') {
127 Some((u, p)) => (percent_decode(u)?, Some(percent_decode(p)?)),
128 None => (percent_decode(ui)?, None),
129 },
130 None => (String::new(), None),
131 };
132
133 let (hostport, db_and_params) = match rest.split_once('/') {
135 Some((hp, rest)) => (hp, Some(rest)),
136 None => (rest, None),
137 };
138
139 let mut hosts: Vec<(String, u16)> = Vec::new();
141 if hostport.is_empty() {
142 } else {
144 for entry in hostport.split(',') {
145 let (h, p) = match entry.rsplit_once(':') {
146 Some((h, p)) => {
147 let port: u16 = p
148 .parse()
149 .map_err(|_| Error::Config(format!("invalid port: {p}")))?;
150 (h.to_string(), port)
151 }
152 None => (entry.to_string(), 5432),
153 };
154 hosts.push((h, p));
155 }
156 }
157
158 let (database, params_str) = match db_and_params {
159 Some(dp) => match dp.split_once('?') {
160 Some((db, params)) => (percent_decode(db)?, Some(params.to_string())),
161 None => (percent_decode(dp)?, None),
162 },
163 None => (String::new(), None),
164 };
165
166 let mut config = ConfigBuilder::new();
167 for (h, p) in &hosts {
168 config = config.host_port(h.clone(), *p);
169 }
170 config = config.database(database).user(user);
171
172 if let Some(pw) = password {
173 config = config.password(pw);
174 }
175
176 if let Some(params) = params_str {
178 for param in params.split('&') {
179 let (key, value) = param
180 .split_once('=')
181 .ok_or_else(|| Error::Config(format!("invalid parameter: {param}")))?;
182 let value = percent_decode(value)?;
183
184 match key {
185 "sslmode" => {
186 config = config.ssl_mode(match value.as_str() {
187 "disable" => SslMode::Disable,
188 "prefer" => SslMode::Prefer,
189 "require" => SslMode::Require,
190 "verify-ca" => SslMode::VerifyCa,
191 "verify-full" => SslMode::VerifyFull,
192 _ => return Err(Error::Config(format!("invalid sslmode: {value}"))),
193 });
194 }
195 "application_name" => {
196 config = config.application_name(value);
197 }
198 "connect_timeout" => {
199 let secs: u64 = value.parse().map_err(|_| {
200 Error::Config(format!("invalid connect_timeout: {value}"))
201 })?;
202 config = config.connect_timeout(Duration::from_secs(secs));
203 }
204 "statement_timeout" => {
205 let secs: u64 = value.parse().map_err(|_| {
206 Error::Config(format!("invalid statement_timeout: {value}"))
207 })?;
208 config = config.statement_timeout(Duration::from_secs(secs));
209 }
210 "target_session_attrs" => {
211 config = config.target_session_attrs(match value.as_str() {
212 "any" => TargetSessionAttrs::Any,
213 "read-write" => TargetSessionAttrs::ReadWrite,
214 "read-only" => TargetSessionAttrs::ReadOnly,
215 _ => {
216 return Err(Error::Config(format!(
217 "invalid target_session_attrs: {value}"
218 )))
219 }
220 });
221 }
222 "sslcert" => {
223 config = config.ssl_client_cert(PathBuf::from(value));
224 }
225 "sslkey" => {
226 config = config.ssl_client_key(PathBuf::from(value));
227 }
228 "ssldirect" | "sslnegotiation" => {
229 let direct = match value.as_str() {
230 "true" | "direct" => true,
231 "false" | "postgres" => false,
232 _ => return Err(Error::Config(format!("invalid {key}: {value}"))),
233 };
234 config = config.ssl_direct(direct);
235 }
236 "channel_binding" => {
237 config = config.channel_binding(match value.as_str() {
238 "prefer" => ChannelBinding::Prefer,
239 "require" => ChannelBinding::Require,
240 "disable" => ChannelBinding::Disable,
241 _ => {
242 return Err(Error::Config(format!(
243 "invalid channel_binding: {value}"
244 )))
245 }
246 });
247 }
248 "load_balance_hosts" => {
249 config = config.load_balance_hosts(match value.as_str() {
250 "disable" => LoadBalanceHosts::Disable,
251 "random" => LoadBalanceHosts::Random,
252 _ => {
253 return Err(Error::Config(format!(
254 "invalid load_balance_hosts: {value}"
255 )))
256 }
257 });
258 }
259 "host" => {
260 config = config.host_port(value, 5432);
262 }
263 _ => {
264 }
266 }
267 }
268 }
269
270 Ok(config.build())
271 }
272
273 pub fn builder() -> ConfigBuilder {
275 ConfigBuilder::new()
276 }
277
278 pub fn host(&self) -> &str {
282 self.hosts.first().map_or("localhost", |(h, _)| h.as_str())
283 }
284
285 pub fn port(&self) -> u16 {
287 self.hosts.first().map_or(5432, |(_, p)| *p)
288 }
289
290 pub fn hosts(&self) -> &[(String, u16)] {
292 &self.hosts
293 }
294
295 pub fn load_balance_hosts(&self) -> LoadBalanceHosts {
297 self.load_balance_hosts
298 }
299
300 pub fn target_session_attrs(&self) -> TargetSessionAttrs {
302 self.target_session_attrs
303 }
304
305 pub fn database(&self) -> &str {
306 &self.database
307 }
308
309 pub fn user(&self) -> &str {
310 &self.user
311 }
312
313 pub fn password(&self) -> Option<&str> {
314 self.password.as_deref()
315 }
316
317 pub fn ssl_mode(&self) -> SslMode {
318 self.ssl_mode
319 }
320
321 pub fn application_name(&self) -> Option<&str> {
322 self.application_name.as_deref()
323 }
324
325 pub fn connect_timeout(&self) -> Duration {
326 self.connect_timeout
327 }
328
329 pub fn statement_timeout(&self) -> Option<Duration> {
330 self.statement_timeout
331 }
332
333 pub fn ssl_client_cert(&self) -> Option<&std::path::Path> {
335 self.ssl_client_cert.as_deref()
336 }
337
338 pub fn ssl_client_key(&self) -> Option<&std::path::Path> {
340 self.ssl_client_key.as_deref()
341 }
342
343 pub fn ssl_direct(&self) -> bool {
345 self.ssl_direct
346 }
347
348 pub fn channel_binding(&self) -> ChannelBinding {
350 self.channel_binding
351 }
352}
353
354#[derive(Debug, Clone)]
356pub struct ConfigBuilder {
357 hosts: Vec<(String, u16)>,
358 default_port: u16,
359 database: String,
360 user: String,
361 password: Option<String>,
362 ssl_mode: SslMode,
363 application_name: Option<String>,
364 connect_timeout: Duration,
365 statement_timeout: Option<Duration>,
366 keepalive: Option<Duration>,
367 keepalive_idle: Option<Duration>,
368 target_session_attrs: TargetSessionAttrs,
369 extra_float_digits: Option<i32>,
370 load_balance_hosts: LoadBalanceHosts,
371 ssl_client_cert: Option<PathBuf>,
372 ssl_client_key: Option<PathBuf>,
373 ssl_direct: bool,
374 channel_binding: ChannelBinding,
375}
376
377impl ConfigBuilder {
378 fn new() -> Self {
379 Self {
380 hosts: Vec::new(),
381 default_port: 5432,
382 database: String::new(),
383 user: String::new(),
384 password: None,
385 ssl_mode: SslMode::default(),
386 application_name: None,
387 connect_timeout: Duration::from_secs(10),
388 statement_timeout: None,
389 keepalive: Some(Duration::from_secs(60)),
390 keepalive_idle: None,
391 target_session_attrs: TargetSessionAttrs::default(),
392 extra_float_digits: Some(3),
393 load_balance_hosts: LoadBalanceHosts::default(),
394 ssl_client_cert: None,
395 ssl_client_key: None,
396 ssl_direct: false,
397 channel_binding: ChannelBinding::default(),
398 }
399 }
400
401 pub fn host(mut self, host: impl Into<String>) -> Self {
403 self.hosts.push((host.into(), self.default_port));
404 self
405 }
406
407 pub fn host_port(mut self, host: impl Into<String>, port: u16) -> Self {
409 self.hosts.push((host.into(), port));
410 self
411 }
412
413 pub fn port(mut self, port: u16) -> Self {
416 let old_default = self.default_port;
417 self.default_port = port;
418 for (_, p) in &mut self.hosts {
419 if *p == old_default {
420 *p = port;
421 }
422 }
423 self
424 }
425
426 pub fn load_balance_hosts(mut self, strategy: LoadBalanceHosts) -> Self {
427 self.load_balance_hosts = strategy;
428 self
429 }
430
431 pub fn database(mut self, database: impl Into<String>) -> Self {
432 self.database = database.into();
433 self
434 }
435
436 pub fn user(mut self, user: impl Into<String>) -> Self {
437 self.user = user.into();
438 self
439 }
440
441 pub fn password(mut self, password: impl Into<String>) -> Self {
442 self.password = Some(password.into());
443 self
444 }
445
446 pub fn ssl_mode(mut self, ssl_mode: SslMode) -> Self {
447 self.ssl_mode = ssl_mode;
448 self
449 }
450
451 pub fn application_name(mut self, name: impl Into<String>) -> Self {
452 self.application_name = Some(name.into());
453 self
454 }
455
456 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
457 self.connect_timeout = timeout;
458 self
459 }
460
461 pub fn statement_timeout(mut self, timeout: Duration) -> Self {
462 self.statement_timeout = Some(timeout);
463 self
464 }
465
466 pub fn keepalive(mut self, interval: Duration) -> Self {
467 self.keepalive = Some(interval);
468 self
469 }
470
471 pub fn target_session_attrs(mut self, attrs: TargetSessionAttrs) -> Self {
472 self.target_session_attrs = attrs;
473 self
474 }
475
476 pub fn ssl_client_cert(mut self, path: impl Into<PathBuf>) -> Self {
478 self.ssl_client_cert = Some(path.into());
479 self
480 }
481
482 pub fn ssl_client_key(mut self, path: impl Into<PathBuf>) -> Self {
484 self.ssl_client_key = Some(path.into());
485 self
486 }
487
488 pub fn ssl_direct(mut self, direct: bool) -> Self {
490 self.ssl_direct = direct;
491 self
492 }
493
494 pub fn channel_binding(mut self, binding: ChannelBinding) -> Self {
496 self.channel_binding = binding;
497 self
498 }
499
500 pub fn build(self) -> Config {
502 let hosts = if self.hosts.is_empty() {
503 vec![("localhost".to_string(), self.default_port)]
504 } else {
505 self.hosts
506 };
507 Config {
508 hosts,
509 database: self.database,
510 user: self.user,
511 password: self.password,
512 ssl_mode: self.ssl_mode,
513 application_name: self.application_name,
514 connect_timeout: self.connect_timeout,
515 statement_timeout: self.statement_timeout,
516 _keepalive: self.keepalive,
517 _keepalive_idle: self.keepalive_idle,
518 target_session_attrs: self.target_session_attrs,
519 _extra_float_digits: self.extra_float_digits,
520 load_balance_hosts: self.load_balance_hosts,
521 ssl_client_cert: self.ssl_client_cert,
522 ssl_client_key: self.ssl_client_key,
523 ssl_direct: self.ssl_direct,
524 channel_binding: self.channel_binding,
525 }
526 }
527}
528
529fn percent_decode(s: &str) -> Result<String> {
531 let mut result = String::with_capacity(s.len());
532 let mut chars = s.as_bytes().iter();
533
534 while let Some(&b) = chars.next() {
535 if b == b'%' {
536 let hi = chars
537 .next()
538 .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
539 let lo = chars
540 .next()
541 .ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
542 let byte = hex_digit(*hi)? << 4 | hex_digit(*lo)?;
543 result.push(byte as char);
544 } else {
545 result.push(b as char);
546 }
547 }
548
549 Ok(result)
550}
551
552fn hex_digit(b: u8) -> Result<u8> {
553 match b {
554 b'0'..=b'9' => Ok(b - b'0'),
555 b'a'..=b'f' => Ok(b - b'a' + 10),
556 b'A'..=b'F' => Ok(b - b'A' + 10),
557 _ => Err(Error::Config(format!("invalid hex digit: {}", b as char))),
558 }
559}