1use core::{fmt, iter, mem, str};
4
5use std::{
6 borrow::Cow,
7 path::{Path, PathBuf},
8};
9
10use super::{error::Error, session::TargetSessionAttrs};
11
/// TLS policy for a connection, set by the `sslmode` connection parameter.
///
/// The default is `Prefer` when the crate is built with the `tls` feature and `Disable`
/// otherwise.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslMode {
    /// Spelled `sslmode=disable`; the default without the `tls` feature.
    #[cfg_attr(not(feature = "tls"), default)]
    Disable,
    /// Spelled `sslmode=prefer`; the default with the `tls` feature.
    #[cfg_attr(feature = "tls", default)]
    Prefer,
    /// Spelled `sslmode=require`.
    Require,
}
24
/// SSL negotiation strategy, set by the `sslnegotiation` connection parameter.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslNegotiation {
    /// Spelled `sslnegotiation=postgres`; the default.
    #[default]
    Postgres,
    /// Spelled `sslnegotiation=direct`.
    Direct,
}
35
/// One entry in a configuration's host list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A TCP host name or address. `Config::host` stores any value not starting with `/` here.
    Tcp(Box<str>),
    /// A QUIC endpoint.
    ///
    /// NOTE(review): nothing in the visible parsing code constructs this variant.
    Quic(Box<str>),
    /// A Unix-domain socket path, stored by `Config::host_path`.
    Unix(PathBuf),
}
45
46#[derive(Clone, Eq, PartialEq)]
47pub struct Config {
48 pub(crate) user: Option<Box<str>>,
49 pub(crate) password: Option<Box<[u8]>>,
50 pub(crate) dbname: Option<Box<str>>,
51 pub(crate) options: Option<Box<str>>,
52 pub(crate) application_name: Option<Box<str>>,
53 pub(crate) ssl_mode: SslMode,
54 pub(crate) ssl_negotiation: SslNegotiation,
55 pub(crate) host: Vec<Host>,
56 pub(crate) port: Vec<u16>,
57 target_session_attrs: TargetSessionAttrs,
58 tls_server_end_point: Option<Box<[u8]>>,
59}
60
61impl Default for Config {
62 fn default() -> Config {
63 Config::new()
64 }
65}
66
67impl Config {
68 pub fn new() -> Config {
70 Config {
71 user: None,
72 password: None,
73 dbname: None,
74 options: None,
75 application_name: None,
76 ssl_mode: SslMode::default(),
77 ssl_negotiation: SslNegotiation::Postgres,
78 host: Vec::new(),
79 port: Vec::new(),
80 target_session_attrs: TargetSessionAttrs::Any,
81 tls_server_end_point: None,
82 }
83 }
84
85 pub fn user(&mut self, user: &str) -> &mut Config {
89 self.user = Some(Box::from(user));
90 self
91 }
92
93 pub fn get_user(&self) -> Option<&str> {
96 self.user.as_deref()
97 }
98
99 pub fn password<T>(&mut self, password: T) -> &mut Config
101 where
102 T: AsRef<[u8]>,
103 {
104 self.password = Some(Box::from(password.as_ref()));
105 self
106 }
107
108 pub fn get_password(&self) -> Option<&[u8]> {
111 self.password.as_deref()
112 }
113
114 pub fn dbname(&mut self, dbname: &str) -> &mut Config {
118 self.dbname = Some(Box::from(dbname));
119 self
120 }
121
122 pub fn get_dbname(&self) -> Option<&str> {
125 self.dbname.as_deref()
126 }
127
128 pub fn options(&mut self, options: &str) -> &mut Config {
130 self.options = Some(Box::from(options));
131 self
132 }
133
134 pub fn get_options(&self) -> Option<&str> {
137 self.options.as_deref()
138 }
139
140 pub fn application_name(&mut self, application_name: &str) -> &mut Config {
142 self.application_name = Some(Box::from(application_name));
143 self
144 }
145
146 pub fn get_application_name(&self) -> Option<&str> {
149 self.application_name.as_deref()
150 }
151
152 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
156 self.ssl_mode = ssl_mode;
157 self
158 }
159
160 pub fn get_ssl_mode(&self) -> SslMode {
162 self.ssl_mode
163 }
164
165 pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
169 self.ssl_negotiation = ssl_negotiation;
170 self
171 }
172
173 pub fn get_ssl_negotiation(&self) -> SslNegotiation {
175 self.ssl_negotiation
176 }
177
178 pub fn host(&mut self, host: &str) -> &mut Config {
179 if host.starts_with('/') {
180 return self.host_path(host);
181 }
182
183 let host = Host::Tcp(Box::from(host));
184
185 self.host.push(host);
186 self
187 }
188
189 pub fn host_path<T>(&mut self, host: T) -> &mut Config
193 where
194 T: AsRef<Path>,
195 {
196 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
197 self
198 }
199
200 pub fn get_hosts(&self) -> &[Host] {
202 &self.host
203 }
204
205 pub fn port(&mut self, port: u16) -> &mut Config {
211 self.port.push(port);
212 self
213 }
214
215 pub fn get_ports(&self) -> &[u16] {
217 &self.port
218 }
219
220 pub fn target_session_attrs(&mut self, target_session_attrs: TargetSessionAttrs) -> &mut Config {
225 self.target_session_attrs = target_session_attrs;
226 self
227 }
228
229 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
231 self.target_session_attrs
232 }
233
234 pub fn tls_server_end_point(&mut self, tls_server_end_point: impl AsRef<[u8]>) -> &mut Self {
303 self.tls_server_end_point = Some(Box::from(tls_server_end_point.as_ref()));
304 self
305 }
306
307 pub fn get_tls_server_end_point(&self) -> Option<&[u8]> {
308 self.tls_server_end_point.as_deref()
309 }
310
311 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
312 match key {
313 "user" => {
314 self.user(value);
315 }
316 "password" => {
317 self.password(value);
318 }
319 "dbname" => {
320 self.dbname(value);
321 }
322 "options" => {
323 self.options(value);
324 }
325 "application_name" => {
326 self.application_name(value);
327 }
328 "sslmode" => {
329 let mode = match value {
330 "disable" => SslMode::Disable,
331 "prefer" => SslMode::Prefer,
332 "require" => SslMode::Require,
333 _ => return Err(Error::todo()),
334 };
335 self.ssl_mode(mode);
336 }
337 "sslnegotiation" => {
338 let mode = match value {
339 "postgres" => SslNegotiation::Postgres,
340 "direct" => SslNegotiation::Direct,
341 _ => return Err(Error::todo()),
342 };
343 self.ssl_negotiation(mode);
344 }
345 "host" => {
346 for host in value.split(',') {
347 self.host(host);
348 }
349 }
350 "port" => {
351 for port in value.split(',') {
352 let port = if port.is_empty() {
353 5432
354 } else {
355 port.parse().map_err(|_| Error::todo())?
356 };
357 self.port(port);
358 }
359 }
360 "target_session_attrs" => {
361 let target_session_attrs = match value {
362 "any" => TargetSessionAttrs::Any,
363 "read-write" => TargetSessionAttrs::ReadWrite,
364 "read-only" => TargetSessionAttrs::ReadOnly,
365 _ => return Err(Error::todo()),
366 };
367 self.target_session_attrs(target_session_attrs);
368 }
369 _ => {
370 return Err(Error::todo());
371 }
372 }
373
374 Ok(())
375 }
376}
377
378impl TryFrom<String> for Config {
379 type Error = Error;
380
381 fn try_from(s: String) -> Result<Self, Self::Error> {
382 Self::try_from(s.as_str())
383 }
384}
385
386impl TryFrom<&str> for Config {
387 type Error = Error;
388
389 fn try_from(s: &str) -> Result<Self, Self::Error> {
390 match UrlParser::parse(s)? {
391 Some(config) => Ok(config),
392 None => Parser::parse(s),
393 }
394 }
395}
396
397impl fmt::Debug for Config {
399 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400 struct Redaction {}
401 impl fmt::Debug for Redaction {
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 write!(f, "_")
404 }
405 }
406
407 f.debug_struct("Config")
408 .field("user", &self.user)
409 .field("password", &self.password.as_ref().map(|_| Redaction {}))
410 .field("dbname", &self.dbname)
411 .field("options", &self.options)
412 .field("application_name", &self.application_name)
413 .field("host", &self.host)
414 .field("port", &self.port)
415 .field("target_session_attrs", &self.target_session_attrs)
416 .finish()
417 }
418}
419
/// Cursor over a `key=value` connection string.
struct Parser<'a> {
    /// The complete input; `take_while` returns sub-slices of it.
    s: &'a str,
    /// Position within `s`, as `(byte offset, char)` pairs with one-char lookahead.
    it: core::iter::Peekable<core::str::CharIndices<'a>>,
}
424
425impl<'a> Parser<'a> {
426 fn parse(s: &'a str) -> Result<Config, Error> {
427 let mut parser = Parser {
428 s,
429 it: s.char_indices().peekable(),
430 };
431
432 let mut config = Config::new();
433
434 while let Some((key, value)) = parser.parameter()? {
435 config.param(key, &value)?;
436 }
437
438 Ok(config)
439 }
440
441 fn skip_ws(&mut self) {
442 self.take_while(char::is_whitespace);
443 }
444
445 fn take_while<F>(&mut self, f: F) -> &'a str
446 where
447 F: Fn(char) -> bool,
448 {
449 let start = match self.it.peek() {
450 Some(&(i, _)) => i,
451 None => return "",
452 };
453
454 loop {
455 match self.it.peek() {
456 Some(&(_, c)) if f(c) => {
457 self.it.next();
458 }
459 Some(&(i, _)) => return &self.s[start..i],
460 None => return &self.s[start..],
461 }
462 }
463 }
464
465 fn eat(&mut self, target: char) -> Result<(), Error> {
466 match self.it.next() {
467 Some((_, c)) if c == target => Ok(()),
468 Some((i, c)) => {
469 let _m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
470 Err(Error::todo())
471 }
472 None => Err(Error::todo()),
473 }
474 }
475
476 fn eat_if(&mut self, target: char) -> bool {
477 match self.it.peek() {
478 Some(&(_, c)) if c == target => {
479 self.it.next();
480 true
481 }
482 _ => false,
483 }
484 }
485
486 fn keyword(&mut self) -> Option<&'a str> {
487 let s = self.take_while(|c| match c {
488 c if c.is_whitespace() => false,
489 '=' => false,
490 _ => true,
491 });
492
493 if s.is_empty() { None } else { Some(s) }
494 }
495
496 fn value(&mut self) -> Result<String, Error> {
497 let value = if self.eat_if('\'') {
498 let value = self.quoted_value()?;
499 self.eat('\'')?;
500 value
501 } else {
502 self.simple_value()?
503 };
504
505 Ok(value)
506 }
507
508 fn simple_value(&mut self) -> Result<String, Error> {
509 let mut value = String::new();
510
511 while let Some(&(_, c)) = self.it.peek() {
512 if c.is_whitespace() {
513 break;
514 }
515
516 self.it.next();
517 if c == '\\' {
518 if let Some((_, c2)) = self.it.next() {
519 value.push(c2);
520 }
521 } else {
522 value.push(c);
523 }
524 }
525
526 if value.is_empty() {
527 return Err(Error::todo());
528 }
529
530 Ok(value)
531 }
532
533 fn quoted_value(&mut self) -> Result<String, Error> {
534 let mut value = String::new();
535
536 while let Some(&(_, c)) = self.it.peek() {
537 if c == '\'' {
538 return Ok(value);
539 }
540
541 self.it.next();
542 if c == '\\' {
543 if let Some((_, c2)) = self.it.next() {
544 value.push(c2);
545 }
546 } else {
547 value.push(c);
548 }
549 }
550
551 Err(Error::todo())
552 }
553
554 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
555 self.skip_ws();
556 let keyword = match self.keyword() {
557 Some(keyword) => keyword,
558 None => return Ok(None),
559 };
560 self.skip_ws();
561 self.eat('=')?;
562 self.skip_ws();
563 let value = self.value()?;
564
565 Ok(Some((keyword, value)))
566 }
567}
568
569struct UrlParser<'a> {
571 s: &'a str,
572 config: Config,
573}
574
575impl<'a> UrlParser<'a> {
576 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
577 let s = match Self::remove_url_prefix(s) {
578 Some(s) => s,
579 None => return Ok(None),
580 };
581
582 let mut parser = UrlParser {
583 s,
584 config: Config::new(),
585 };
586
587 parser.parse_credentials()?;
588 parser.parse_host()?;
589 parser.parse_path()?;
590 parser.parse_params()?;
591
592 Ok(Some(parser.config))
593 }
594
595 fn remove_url_prefix(s: &str) -> Option<&str> {
596 for prefix in &["postgres://", "postgresql://"] {
597 if let Some(stripped) = s.strip_prefix(prefix) {
598 return Some(stripped);
599 }
600 }
601
602 None
603 }
604
605 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
606 match self.s.find(end) {
607 Some(pos) => {
608 let (head, tail) = self.s.split_at(pos);
609 self.s = tail;
610 Some(head)
611 }
612 None => None,
613 }
614 }
615
616 fn take_all(&mut self) -> &'a str {
617 mem::take(&mut self.s)
618 }
619
620 fn eat_byte(&mut self) {
621 self.s = &self.s[1..];
622 }
623
624 fn parse_credentials(&mut self) -> Result<(), Error> {
625 let creds = match self.take_until(&['@']) {
626 Some(creds) => creds,
627 None => return Ok(()),
628 };
629 self.eat_byte();
630
631 let mut it = creds.splitn(2, ':');
632 let user = self.decode(it.next().unwrap())?;
633 self.config.user(&user);
634
635 if let Some(password) = it.next() {
636 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
637 self.config.password(password);
638 }
639
640 Ok(())
641 }
642
643 fn parse_host(&mut self) -> Result<(), Error> {
644 let host = match self.take_until(&['/', '?']) {
645 Some(host) => host,
646 None => self.take_all(),
647 };
648
649 if host.is_empty() {
650 return Ok(());
651 }
652
653 for chunk in host.split(',') {
654 let (host, port) = if chunk.starts_with('[') {
655 let idx = match chunk.find(']') {
656 Some(idx) => idx,
657 None => return Err(Error::todo()),
658 };
659
660 let host = &chunk[1..idx];
661 let remaining = &chunk[idx + 1..];
662 let port = if let Some(port) = remaining.strip_prefix(':') {
663 Some(port)
664 } else if remaining.is_empty() {
665 None
666 } else {
667 return Err(Error::todo());
668 };
669
670 (host, port)
671 } else {
672 let mut it = chunk.splitn(2, ':');
673 (it.next().unwrap(), it.next())
674 };
675
676 self.host_param(host)?;
677 let port = self.decode(port.unwrap_or("5432"))?;
678 self.config.param("port", &port)?;
679 }
680
681 Ok(())
682 }
683
684 fn parse_path(&mut self) -> Result<(), Error> {
685 if !self.s.starts_with('/') {
686 return Ok(());
687 }
688 self.eat_byte();
689
690 let dbname = match self.take_until(&['?']) {
691 Some(dbname) => dbname,
692 None => self.take_all(),
693 };
694
695 if !dbname.is_empty() {
696 self.config.dbname(&self.decode(dbname)?);
697 }
698
699 Ok(())
700 }
701
702 fn parse_params(&mut self) -> Result<(), Error> {
703 if !self.s.starts_with('?') {
704 return Ok(());
705 }
706 self.eat_byte();
707
708 while !self.s.is_empty() {
709 let key = match self.take_until(&['=']) {
710 Some(key) => self.decode(key)?,
711 None => return Err(Error::todo()),
712 };
713 self.eat_byte();
714
715 let value = match self.take_until(&['&']) {
716 Some(value) => {
717 self.eat_byte();
718 value
719 }
720 None => self.take_all(),
721 };
722
723 if key == "host" {
724 self.host_param(value)?;
725 } else {
726 let value = self.decode(value)?;
727 self.config.param(&key, &value)?;
728 }
729 }
730
731 Ok(())
732 }
733
734 fn host_param(&mut self, s: &str) -> Result<(), Error> {
735 let s = self.decode(s)?;
736 self.config.param("host", &s)
737 }
738
739 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
740 percent_encoding::percent_decode(s.as_bytes())
741 .decode_utf8()
742 .map_err(|_| Error::todo())
743 }
744}