1#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect::yb_connect;
6use crate::connect_raw::connect_raw;
7#[cfg(not(target_arch = "wasm32"))]
8use crate::keepalive::KeepaliveConfig;
9#[cfg(feature = "runtime")]
10use crate::tls::MakeTlsConnect;
11use crate::tls::TlsConnect;
12#[cfg(feature = "runtime")]
13use crate::Socket;
14use crate::{Client, Connection, Error};
15use std::borrow::Cow;
16use std::collections::HashMap;
17#[cfg(unix)]
18use std::ffi::OsStr;
19use std::net::IpAddr;
20use std::ops::Deref;
21#[cfg(unix)]
22use std::os::unix::ffi::OsStrExt;
23#[cfg(unix)]
24use std::path::{Path, PathBuf};
25use std::str;
26use std::str::FromStr;
27use std::time::Duration;
28use std::{error, fmt, iter, mem};
29use tokio::io::{AsyncRead, AsyncWrite};
30
/// Properties required of a session, set with the `target_session_attrs` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// `target_session_attrs=any` (the default): no particular properties are required.
    Any,
    /// `target_session_attrs=read-write`.
    ReadWrite,
}
40
/// SSL/TLS negotiation mode, set with the `sslmode` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// `sslmode=disable`.
    Disable,
    /// `sslmode=prefer` (the default).
    Prefer,
    /// `sslmode=require`.
    Require,
}
52
/// Channel binding setting, set with the `channel_binding` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// `channel_binding=disable`.
    Disable,
    /// `channel_binding=prefer` (the default).
    Prefer,
    /// `channel_binding=require`.
    Require,
}
64
/// Host load-balancing setting, set with the `load_balance_hosts` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
    /// `load_balance_hosts=disable` (the default).
    Disable,
    /// `load_balance_hosts=random`.
    Random,
}
74
/// A host to connect to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Host {
    /// A hostname or textual address reached over TCP.
    Tcp(String),
    /// A filesystem path (only available on Unix targets).
    #[cfg(unix)]
    Unix(PathBuf),
}
86
87#[derive(Clone, PartialEq, Eq)]
211pub struct Config {
212 pub(crate) user: Option<String>,
213 pub(crate) password: Option<Vec<u8>>,
214 pub(crate) dbname: Option<String>,
215 pub(crate) options: Option<String>,
216 pub(crate) application_name: Option<String>,
217 pub(crate) ssl_mode: SslMode,
218 pub(crate) host: Vec<Host>,
219 pub(crate) hostaddr: Vec<IpAddr>,
220 pub(crate) port: Vec<u16>,
221 pub(crate) connect_timeout: Option<Duration>,
222 pub(crate) tcp_user_timeout: Option<Duration>,
223 pub(crate) keepalives: bool,
224 #[cfg(not(target_arch = "wasm32"))]
225 pub(crate) keepalive_config: KeepaliveConfig,
226 pub(crate) target_session_attrs: TargetSessionAttrs,
227 pub(crate) channel_binding: ChannelBinding,
228 pub(crate) load_balance_hosts: LoadBalanceHosts,
229 pub(crate) load_balance: String,
231 pub(crate) topology_keys: HashMap<i64, Vec<String>>,
232 pub(crate) yb_servers_refresh_interval: Duration,
233 pub(crate) fallback_to_topology_keys_only: bool,
234 pub(crate) failed_host_reconnect_delay_secs: Duration,
235}
236
237impl Default for Config {
238 fn default() -> Config {
239 Config::new()
240 }
241}
242
243impl Config {
244 pub fn new() -> Config {
246 Config {
247 user: None,
248 password: None,
249 dbname: None,
250 options: None,
251 application_name: None,
252 ssl_mode: SslMode::Prefer,
253 host: vec![],
254 hostaddr: vec![],
255 port: vec![],
256 connect_timeout: None,
257 tcp_user_timeout: None,
258 keepalives: true,
259 #[cfg(not(target_arch = "wasm32"))]
260 keepalive_config: KeepaliveConfig {
261 idle: Duration::from_secs(2 * 60 * 60),
262 interval: None,
263 retries: None,
264 },
265 target_session_attrs: TargetSessionAttrs::Any,
266 channel_binding: ChannelBinding::Prefer,
267 load_balance_hosts: LoadBalanceHosts::Disable,
268 load_balance: String::from("false"),
269 topology_keys: HashMap::new(),
270 yb_servers_refresh_interval: Duration::new(300, 0),
271 fallback_to_topology_keys_only: false,
272 failed_host_reconnect_delay_secs: Duration::new(5, 0),
273 }
274 }
275
276 pub fn user(&mut self, user: &str) -> &mut Config {
280 self.user = Some(user.to_string());
281 self
282 }
283
284 pub fn get_user(&self) -> Option<&str> {
287 self.user.as_deref()
288 }
289
290 pub fn password<T>(&mut self, password: T) -> &mut Config
292 where
293 T: AsRef<[u8]>,
294 {
295 self.password = Some(password.as_ref().to_vec());
296 self
297 }
298
299 pub fn get_password(&self) -> Option<&[u8]> {
302 self.password.as_deref()
303 }
304
305 pub fn dbname(&mut self, dbname: &str) -> &mut Config {
309 self.dbname = Some(dbname.to_string());
310 self
311 }
312
313 pub fn get_dbname(&self) -> Option<&str> {
316 self.dbname.as_deref()
317 }
318
319 pub fn options(&mut self, options: &str) -> &mut Config {
321 self.options = Some(options.to_string());
322 self
323 }
324
325 pub fn get_options(&self) -> Option<&str> {
328 self.options.as_deref()
329 }
330
331 pub fn application_name(&mut self, application_name: &str) -> &mut Config {
333 self.application_name = Some(application_name.to_string());
334 self
335 }
336
337 pub fn get_application_name(&self) -> Option<&str> {
340 self.application_name.as_deref()
341 }
342
343 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
347 self.ssl_mode = ssl_mode;
348 self
349 }
350
351 pub fn get_ssl_mode(&self) -> SslMode {
353 self.ssl_mode
354 }
355
356 pub fn host(&mut self, host: &str) -> &mut Config {
362 #[cfg(unix)]
363 {
364 if host.starts_with('/') {
365 return self.host_path(host);
366 }
367 }
368
369 self.host.push(Host::Tcp(host.to_string()));
370 self
371 }
372
373 pub fn get_hosts(&self) -> &[Host] {
375 &self.host
376 }
377
378 pub fn get_hostaddrs(&self) -> &[IpAddr] {
380 self.hostaddr.deref()
381 }
382
383 #[cfg(unix)]
387 pub fn host_path<T>(&mut self, host: T) -> &mut Config
388 where
389 T: AsRef<Path>,
390 {
391 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
392 self
393 }
394
395 pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
400 self.hostaddr.push(hostaddr);
401 self
402 }
403
404 pub fn port(&mut self, port: u16) -> &mut Config {
410 self.port.push(port);
411 self
412 }
413
414 pub fn get_ports(&self) -> &[u16] {
416 &self.port
417 }
418
419 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
424 self.connect_timeout = Some(connect_timeout);
425 self
426 }
427
428 pub fn get_connect_timeout(&self) -> Option<&Duration> {
431 self.connect_timeout.as_ref()
432 }
433
434 pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
440 self.tcp_user_timeout = Some(tcp_user_timeout);
441 self
442 }
443
444 pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
447 self.tcp_user_timeout.as_ref()
448 }
449
450 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
454 self.keepalives = keepalives;
455 self
456 }
457
458 pub fn get_keepalives(&self) -> bool {
460 self.keepalives
461 }
462
463 #[cfg(not(target_arch = "wasm32"))]
467 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
468 self.keepalive_config.idle = keepalives_idle;
469 self
470 }
471
472 #[cfg(not(target_arch = "wasm32"))]
475 pub fn get_keepalives_idle(&self) -> Duration {
476 self.keepalive_config.idle
477 }
478
479 #[cfg(not(target_arch = "wasm32"))]
484 pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
485 self.keepalive_config.interval = Some(keepalives_interval);
486 self
487 }
488
489 #[cfg(not(target_arch = "wasm32"))]
491 pub fn get_keepalives_interval(&self) -> Option<Duration> {
492 self.keepalive_config.interval
493 }
494
495 #[cfg(not(target_arch = "wasm32"))]
499 pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
500 self.keepalive_config.retries = Some(keepalives_retries);
501 self
502 }
503
504 #[cfg(not(target_arch = "wasm32"))]
506 pub fn get_keepalives_retries(&self) -> Option<u32> {
507 self.keepalive_config.retries
508 }
509
510 pub fn target_session_attrs(
515 &mut self,
516 target_session_attrs: TargetSessionAttrs,
517 ) -> &mut Config {
518 self.target_session_attrs = target_session_attrs;
519 self
520 }
521
522 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
524 self.target_session_attrs
525 }
526
527 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
531 self.channel_binding = channel_binding;
532 self
533 }
534
535 pub fn get_channel_binding(&self) -> ChannelBinding {
537 self.channel_binding
538 }
539
540 pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
544 self.load_balance_hosts = load_balance_hosts;
545 self
546 }
547
548 pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
550 self.load_balance_hosts
551 }
552
553 pub fn load_balance(&mut self, load_balance: &str) -> &mut Config {
559 self.load_balance = load_balance.to_lowercase();
560 self
561 }
562
563 pub fn get_load_balance(&self) -> String {
567 self.load_balance.clone()
568 }
569
570 pub fn topology_keys(&mut self, topology_key: &str, priority: i64) -> &mut Config {
576 let current_zones: Option<&Vec<String>> = self.topology_keys.get(&priority);
577 if current_zones.is_none() {
578 let mut topology_vec: Vec<String> = Vec::new();
579 topology_vec.push(topology_key.to_owned());
580 self.topology_keys.insert(priority, topology_vec);
581 } else {
582 let mut current_zones_vec: Vec<String> = current_zones.unwrap().to_vec();
583 current_zones_vec.push(topology_key.to_owned());
584 self.topology_keys.insert(priority, current_zones_vec);
585 }
586 self
587 }
588
589 pub fn get_topology_keys(&self) -> HashMap<i64, Vec<String>> {
593 self.topology_keys.clone()
594 }
595
596 pub fn yb_servers_refresh_interval(
602 &mut self,
603 yb_servers_refresh_interval: Duration,
604 ) -> &mut Config {
605 self.yb_servers_refresh_interval = yb_servers_refresh_interval;
606 self
607 }
608
609 pub fn get_yb_servers_refresh_interval(&self) -> Duration {
613 self.yb_servers_refresh_interval
614 }
615
616 pub fn fallback_to_topology_keys_only(
622 &mut self,
623 fallback_to_topology_keys_only: bool,
624 ) -> &mut Config {
625 self.fallback_to_topology_keys_only = fallback_to_topology_keys_only;
626 self
627 }
628
629 pub fn get_fallback_to_topology_keys_only(&self) -> bool {
633 self.fallback_to_topology_keys_only
634 }
635
636 pub fn failed_host_reconnect_delay_secs(
642 &mut self,
643 failed_host_reconnect_delay_secs: Duration,
644 ) -> &mut Config {
645 self.failed_host_reconnect_delay_secs = failed_host_reconnect_delay_secs;
646 self
647 }
648
649 pub fn get_failed_host_reconnect_delay_secs(&self) -> Duration {
653 self.failed_host_reconnect_delay_secs
654 }
655
656 pub fn is_lb_valid(&self, lb: &str) -> bool {
658 match lb.to_lowercase().as_str(){
659 "only-rr" => true,
660 "only-primary"=> true,
661 "prefer-primary"=> true,
662 "prefer-rr"=> true,
663 "any"=> true,
664 "true"=> true,
665 "false"=> true,
666 _=>false,
667 }
668 }
669
670 pub fn is_valid(&self, zone: &str) -> bool {
672 let mut zones: Vec<&str> = zone.split(":").collect();
673 if zones.is_empty() || zones.len() > 2 {
674 return false;
675 }
676 let placement: Vec<&str> = zones[0].split(".").collect();
677 if placement.len() != 3 {
678 return false;
679 }
680 if zones.len() == 1 {
681 zones.push("1");
682 }
683 let priority = zones[1].parse::<i64>();
684 if priority.is_err() {
685 return false;
686 } else {
687 let priorityvalue = priority.unwrap();
688 if !(1..=10).contains(&priorityvalue) {
689 return false;
690 }
691 }
692 true
693 }
694
695 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
696 match key {
697 "user" => {
698 self.user(value);
699 }
700 "password" => {
701 self.password(value);
702 }
703 "dbname" => {
704 self.dbname(value);
705 }
706 "options" => {
707 self.options(value);
708 }
709 "application_name" => {
710 self.application_name(value);
711 }
712 "sslmode" => {
713 let mode = match value {
714 "disable" => SslMode::Disable,
715 "prefer" => SslMode::Prefer,
716 "require" => SslMode::Require,
717 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
718 };
719 self.ssl_mode(mode);
720 }
721 "host" => {
722 for host in value.split(',') {
723 self.host(host);
724 }
725 }
726 "hostaddr" => {
727 for hostaddr in value.split(',') {
728 let addr = hostaddr
729 .parse()
730 .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
731 self.hostaddr(addr);
732 }
733 }
734 "port" => {
735 for port in value.split(',') {
736 let port = if port.is_empty() {
737 5433
738 } else {
739 port.parse()
740 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
741 };
742 self.port(port);
743 }
744 }
745 "connect_timeout" => {
746 let timeout = value
747 .parse::<i64>()
748 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
749 if timeout > 0 {
750 self.connect_timeout(Duration::from_secs(timeout as u64));
751 }
752 }
753 "tcp_user_timeout" => {
754 let timeout = value
755 .parse::<i64>()
756 .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
757 if timeout > 0 {
758 self.tcp_user_timeout(Duration::from_secs(timeout as u64));
759 }
760 }
761 #[cfg(not(target_arch = "wasm32"))]
762 "keepalives" => {
763 let keepalives = value
764 .parse::<u64>()
765 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
766 self.keepalives(keepalives != 0);
767 }
768 #[cfg(not(target_arch = "wasm32"))]
769 "keepalives_idle" => {
770 let keepalives_idle = value
771 .parse::<i64>()
772 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
773 if keepalives_idle > 0 {
774 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
775 }
776 }
777 #[cfg(not(target_arch = "wasm32"))]
778 "keepalives_interval" => {
779 let keepalives_interval = value.parse::<i64>().map_err(|_| {
780 Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
781 })?;
782 if keepalives_interval > 0 {
783 self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
784 }
785 }
786 #[cfg(not(target_arch = "wasm32"))]
787 "keepalives_retries" => {
788 let keepalives_retries = value.parse::<u32>().map_err(|_| {
789 Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
790 })?;
791 self.keepalives_retries(keepalives_retries);
792 }
793 "target_session_attrs" => {
794 let target_session_attrs = match value {
795 "any" => TargetSessionAttrs::Any,
796 "read-write" => TargetSessionAttrs::ReadWrite,
797 _ => {
798 return Err(Error::config_parse(Box::new(InvalidValue(
799 "target_session_attrs",
800 ))));
801 }
802 };
803 self.target_session_attrs(target_session_attrs);
804 }
805 "channel_binding" => {
806 let channel_binding = match value {
807 "disable" => ChannelBinding::Disable,
808 "prefer" => ChannelBinding::Prefer,
809 "require" => ChannelBinding::Require,
810 _ => {
811 return Err(Error::config_parse(Box::new(InvalidValue(
812 "channel_binding",
813 ))))
814 }
815 };
816 self.channel_binding(channel_binding);
817 }
818 "load_balance_hosts" => {
819 let load_balance_hosts = match value {
820 "disable" => LoadBalanceHosts::Disable,
821 "random" => LoadBalanceHosts::Random,
822 _ => {
823 return Err(Error::config_parse(Box::new(InvalidValue(
824 "load_balance_hosts",
825 ))))
826 }
827 };
828 self.load_balance_hosts(load_balance_hosts);
829 }
830 "load_balance" => {
831 if self.is_lb_valid(value) {
832 self.load_balance(value);
833 } else {
834 return Err(Error::config_parse(Box::new(InvalidValue("load_balance"))));
835 }
836 }
837 "topology_keys" => {
838 for topology_keys in value.split(',') {
839 if self.is_valid(topology_keys) {
840 let mut zones: Vec<&str> = topology_keys.split(":").collect();
841 if zones.len() == 1 {
842 zones.push("1");
843 }
844 let priority = zones[1].parse::<i64>().unwrap();
845 self.topology_keys(zones[0], priority);
846 } else {
847 return Err(Error::config_parse(Box::new(InvalidValue("topology_keys"))));
848 }
849 }
850 }
851 "yb_servers_refresh_interval" => {
852 let refresh_interval = value.parse::<i64>().map_err(|_| {
853 Error::config_parse(Box::new(InvalidValue("yb_servers_refresh_interval")))
854 })?;
855 if (0..=600).contains(&refresh_interval) {
856 self.yb_servers_refresh_interval(Duration::from_secs(refresh_interval as u64));
857 }
858 }
859 "fallback_to_topology_keys_only" => {
860 let fallback_to_topology_keys_only = value.parse::<bool>().map_err(|_| {
861 Error::config_parse(Box::new(InvalidValue("fallback_to_topology_keys_only")))
862 })?;
863 self.fallback_to_topology_keys_only(fallback_to_topology_keys_only);
864 }
865 "failed_host_reconnect_delay_secs" => {
866 let failed_host_reconnect_delay_secs = value.parse::<i64>().map_err(|_| {
867 Error::config_parse(Box::new(InvalidValue("failed_host_reconnect_delay_secs")))
868 })?;
869 if (0..=60).contains(&failed_host_reconnect_delay_secs) {
870 self.failed_host_reconnect_delay_secs(Duration::from_secs(
871 failed_host_reconnect_delay_secs as u64,
872 ));
873 }
874 }
875 key => {
876 return Err(Error::config_parse(Box::new(UnknownOption(
877 key.to_string(),
878 ))));
879 }
880 }
881
882 Ok(())
883 }
884
885 #[cfg(feature = "runtime")]
889 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
890 where
891 T: MakeTlsConnect<Socket>,
892 {
893 if self.load_balance != "false" {
894 yb_connect(tls, self).await
895 } else {
896 connect(tls, self).await
897 }
898 }
899
900 pub async fn connect_raw<S, T>(
904 &self,
905 stream: S,
906 tls: T,
907 ) -> Result<(Client, Connection<S, T::Stream>), Error>
908 where
909 S: AsyncRead + AsyncWrite + Unpin,
910 T: TlsConnect<S>,
911 {
912 connect_raw(stream, tls, true, self).await
913 }
914}
915
916impl FromStr for Config {
917 type Err = Error;
918
919 fn from_str(s: &str) -> Result<Config, Error> {
920 match UrlParser::parse(s)? {
921 Some(config) => Ok(config),
922 None => Parser::parse(s),
923 }
924 }
925}
926
927impl fmt::Debug for Config {
929 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
930 struct Redaction {}
931 impl fmt::Debug for Redaction {
932 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
933 write!(f, "_")
934 }
935 }
936
937 let mut config_dbg = &mut f.debug_struct("Config");
938 config_dbg = config_dbg
939 .field("user", &self.user)
940 .field("password", &self.password.as_ref().map(|_| Redaction {}))
941 .field("dbname", &self.dbname)
942 .field("options", &self.options)
943 .field("application_name", &self.application_name)
944 .field("ssl_mode", &self.ssl_mode)
945 .field("host", &self.host)
946 .field("hostaddr", &self.hostaddr)
947 .field("port", &self.port)
948 .field("connect_timeout", &self.connect_timeout)
949 .field("tcp_user_timeout", &self.tcp_user_timeout)
950 .field("keepalives", &self.keepalives);
951
952 #[cfg(not(target_arch = "wasm32"))]
953 {
954 config_dbg = config_dbg
955 .field("keepalives_idle", &self.keepalive_config.idle)
956 .field("keepalives_interval", &self.keepalive_config.interval)
957 .field("keepalives_retries", &self.keepalive_config.retries);
958 }
959
960 config_dbg
961 .field("target_session_attrs", &self.target_session_attrs)
962 .field("channel_binding", &self.channel_binding)
963 .finish()
964 }
965}
966
/// Error raised when a connection string names an option that is not recognised.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(f, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
977
/// Error raised when a recognised option is given a value it cannot accept.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(f, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
988
989struct Parser<'a> {
990 s: &'a str,
991 it: iter::Peekable<str::CharIndices<'a>>,
992}
993
994impl<'a> Parser<'a> {
995 fn parse(s: &'a str) -> Result<Config, Error> {
996 let mut parser = Parser {
997 s,
998 it: s.char_indices().peekable(),
999 };
1000
1001 let mut config = Config::new();
1002
1003 while let Some((key, value)) = parser.parameter()? {
1004 config.param(key, &value)?;
1005 }
1006
1007 Ok(config)
1008 }
1009
1010 fn skip_ws(&mut self) {
1011 self.take_while(char::is_whitespace);
1012 }
1013
1014 fn take_while<F>(&mut self, f: F) -> &'a str
1015 where
1016 F: Fn(char) -> bool,
1017 {
1018 let start = match self.it.peek() {
1019 Some(&(i, _)) => i,
1020 None => return "",
1021 };
1022
1023 loop {
1024 match self.it.peek() {
1025 Some(&(_, c)) if f(c) => {
1026 self.it.next();
1027 }
1028 Some(&(i, _)) => return &self.s[start..i],
1029 None => return &self.s[start..],
1030 }
1031 }
1032 }
1033
1034 fn eat(&mut self, target: char) -> Result<(), Error> {
1035 match self.it.next() {
1036 Some((_, c)) if c == target => Ok(()),
1037 Some((i, c)) => {
1038 let m = format!(
1039 "unexpected character at byte {}: expected `{}` but got `{}`",
1040 i, target, c
1041 );
1042 Err(Error::config_parse(m.into()))
1043 }
1044 None => Err(Error::config_parse("unexpected EOF".into())),
1045 }
1046 }
1047
1048 fn eat_if(&mut self, target: char) -> bool {
1049 match self.it.peek() {
1050 Some(&(_, c)) if c == target => {
1051 self.it.next();
1052 true
1053 }
1054 _ => false,
1055 }
1056 }
1057
1058 fn keyword(&mut self) -> Option<&'a str> {
1059 let s = self.take_while(|c| match c {
1060 c if c.is_whitespace() => false,
1061 '=' => false,
1062 _ => true,
1063 });
1064
1065 if s.is_empty() {
1066 None
1067 } else {
1068 Some(s)
1069 }
1070 }
1071
1072 fn value(&mut self) -> Result<String, Error> {
1073 let value = if self.eat_if('\'') {
1074 let value = self.quoted_value()?;
1075 self.eat('\'')?;
1076 value
1077 } else {
1078 self.simple_value()?
1079 };
1080
1081 Ok(value)
1082 }
1083
1084 fn simple_value(&mut self) -> Result<String, Error> {
1085 let mut value = String::new();
1086
1087 while let Some(&(_, c)) = self.it.peek() {
1088 if c.is_whitespace() {
1089 break;
1090 }
1091
1092 self.it.next();
1093 if c == '\\' {
1094 if let Some((_, c2)) = self.it.next() {
1095 value.push(c2);
1096 }
1097 } else {
1098 value.push(c);
1099 }
1100 }
1101
1102 if value.is_empty() {
1103 return Err(Error::config_parse("unexpected EOF".into()));
1104 }
1105
1106 Ok(value)
1107 }
1108
1109 fn quoted_value(&mut self) -> Result<String, Error> {
1110 let mut value = String::new();
1111
1112 while let Some(&(_, c)) = self.it.peek() {
1113 if c == '\'' {
1114 return Ok(value);
1115 }
1116
1117 self.it.next();
1118 if c == '\\' {
1119 if let Some((_, c2)) = self.it.next() {
1120 value.push(c2);
1121 }
1122 } else {
1123 value.push(c);
1124 }
1125 }
1126
1127 Err(Error::config_parse(
1128 "unterminated quoted connection parameter value".into(),
1129 ))
1130 }
1131
1132 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
1133 self.skip_ws();
1134 let keyword = match self.keyword() {
1135 Some(keyword) => keyword,
1136 None => return Ok(None),
1137 };
1138 self.skip_ws();
1139 self.eat('=')?;
1140 self.skip_ws();
1141 let value = self.value()?;
1142
1143 Ok(Some((keyword, value)))
1144 }
1145}
1146
1147struct UrlParser<'a> {
1149 s: &'a str,
1150 config: Config,
1151}
1152
1153impl<'a> UrlParser<'a> {
1154 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
1155 let s = match Self::remove_url_prefix(s) {
1156 Some(s) => s,
1157 None => return Ok(None),
1158 };
1159
1160 let mut parser = UrlParser {
1161 s,
1162 config: Config::new(),
1163 };
1164
1165 parser.parse_credentials()?;
1166 parser.parse_host()?;
1167 parser.parse_path()?;
1168 parser.parse_params()?;
1169
1170 Ok(Some(parser.config))
1171 }
1172
1173 fn remove_url_prefix(s: &str) -> Option<&str> {
1174 for prefix in &["postgres://", "postgresql://"] {
1175 if let Some(stripped) = s.strip_prefix(prefix) {
1176 return Some(stripped);
1177 }
1178 }
1179
1180 None
1181 }
1182
1183 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
1184 match self.s.find(end) {
1185 Some(pos) => {
1186 let (head, tail) = self.s.split_at(pos);
1187 self.s = tail;
1188 Some(head)
1189 }
1190 None => None,
1191 }
1192 }
1193
1194 fn take_all(&mut self) -> &'a str {
1195 mem::take(&mut self.s)
1196 }
1197
1198 fn eat_byte(&mut self) {
1199 self.s = &self.s[1..];
1200 }
1201
1202 fn parse_credentials(&mut self) -> Result<(), Error> {
1203 let creds = match self.take_until(&['@']) {
1204 Some(creds) => creds,
1205 None => return Ok(()),
1206 };
1207 self.eat_byte();
1208
1209 let mut it = creds.splitn(2, ':');
1210 let user = self.decode(it.next().unwrap())?;
1211 self.config.user(&user);
1212
1213 if let Some(password) = it.next() {
1214 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
1215 self.config.password(password);
1216 }
1217
1218 Ok(())
1219 }
1220
1221 fn parse_host(&mut self) -> Result<(), Error> {
1222 let host = match self.take_until(&['/', '?']) {
1223 Some(host) => host,
1224 None => self.take_all(),
1225 };
1226
1227 if host.is_empty() {
1228 return Ok(());
1229 }
1230
1231 for chunk in host.split(',') {
1232 let (host, port) = if chunk.starts_with('[') {
1233 let idx = match chunk.find(']') {
1234 Some(idx) => idx,
1235 None => return Err(Error::config_parse(InvalidValue("host").into())),
1236 };
1237
1238 let host = &chunk[1..idx];
1239 let remaining = &chunk[idx + 1..];
1240 let port = if let Some(port) = remaining.strip_prefix(':') {
1241 Some(port)
1242 } else if remaining.is_empty() {
1243 None
1244 } else {
1245 return Err(Error::config_parse(InvalidValue("host").into()));
1246 };
1247
1248 (host, port)
1249 } else {
1250 let mut it = chunk.splitn(2, ':');
1251 (it.next().unwrap(), it.next())
1252 };
1253
1254 self.host_param(host)?;
1255 let port = self.decode(port.unwrap_or("5433"))?;
1256 self.config.param("port", &port)?;
1257 }
1258
1259 Ok(())
1260 }
1261
1262 fn parse_path(&mut self) -> Result<(), Error> {
1263 if !self.s.starts_with('/') {
1264 return Ok(());
1265 }
1266 self.eat_byte();
1267
1268 let dbname = match self.take_until(&['?']) {
1269 Some(dbname) => dbname,
1270 None => self.take_all(),
1271 };
1272
1273 if !dbname.is_empty() {
1274 self.config.dbname(&self.decode(dbname)?);
1275 }
1276
1277 Ok(())
1278 }
1279
1280 fn parse_params(&mut self) -> Result<(), Error> {
1281 if !self.s.starts_with('?') {
1282 return Ok(());
1283 }
1284 self.eat_byte();
1285
1286 while !self.s.is_empty() {
1287 let key = match self.take_until(&['=']) {
1288 Some(key) => self.decode(key)?,
1289 None => return Err(Error::config_parse("unterminated parameter".into())),
1290 };
1291 self.eat_byte();
1292
1293 let value = match self.take_until(&['&']) {
1294 Some(value) => {
1295 self.eat_byte();
1296 value
1297 }
1298 None => self.take_all(),
1299 };
1300
1301 if key == "host" {
1302 self.host_param(value)?;
1303 } else {
1304 let value = self.decode(value)?;
1305 self.config.param(&key, &value)?;
1306 }
1307 }
1308
1309 Ok(())
1310 }
1311
1312 #[cfg(unix)]
1313 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1314 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1315 if decoded.first() == Some(&b'/') {
1316 self.config.host_path(OsStr::from_bytes(&decoded));
1317 } else {
1318 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1319 self.config.host(decoded);
1320 }
1321
1322 Ok(())
1323 }
1324
1325 #[cfg(not(unix))]
1326 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1327 let s = self.decode(s)?;
1328 self.config.param("host", &s)
1329 }
1330
1331 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1332 percent_encoding::percent_decode(s.as_bytes())
1333 .decode_utf8()
1334 .map_err(|e| Error::config_parse(e.into()))
1335 }
1336}
1337
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::time::Duration;

    use crate::{config::Host, Config};

    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );
        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );
        assert_eq!([26257], config.get_ports());
    }

    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }

    #[test]
    fn test_topology_keys_parsing() {
        let config = "topology_keys=cloud1.region1.zone1:1,cloud1.region1.zone2:2,cloud1.region1.zone3"
            .parse::<Config>()
            .unwrap();
        let keys = config.get_topology_keys();
        // A key without an explicit priority defaults to priority 1.
        assert_eq!(
            Some(&vec![
                "cloud1.region1.zone1".to_string(),
                "cloud1.region1.zone3".to_string()
            ]),
            keys.get(&1)
        );
        assert_eq!(Some(&vec!["cloud1.region1.zone2".to_string()]), keys.get(&2));
        assert_eq!(2, keys.len());

        // Placement must have three components and priority must lie in 1..=10.
        assert!("topology_keys=cloud1.region1".parse::<Config>().is_err());
        assert!("topology_keys=cloud1.region1.zone1:11".parse::<Config>().is_err());
        assert!("topology_keys=cloud1.region1.zone1:x".parse::<Config>().is_err());
    }

    #[test]
    fn test_load_balance_parsing() {
        let config = "load_balance=TRUE".parse::<Config>().unwrap();
        assert_eq!("true", config.get_load_balance());
        assert!("load_balance=maybe".parse::<Config>().is_err());
    }

    #[test]
    fn test_yb_interval_options() {
        let config = "yb_servers_refresh_interval=120 fallback_to_topology_keys_only=true"
            .parse::<Config>()
            .unwrap();
        assert_eq!(Duration::from_secs(120), config.get_yb_servers_refresh_interval());
        assert!(config.get_fallback_to_topology_keys_only());

        // Out-of-range values are ignored and the default is kept.
        let config = "yb_servers_refresh_interval=601 failed_host_reconnect_delay_secs=61"
            .parse::<Config>()
            .unwrap();
        assert_eq!(Duration::from_secs(300), config.get_yb_servers_refresh_interval());
        assert_eq!(Duration::from_secs(5), config.get_failed_host_reconnect_delay_secs());
    }

    #[test]
    fn test_url_parsing() {
        let config =
            "postgresql://yugabyte:pa%40ss@host1:5434,host2/yugabyte?load_balance=true"
                .parse::<Config>()
                .unwrap();
        assert_eq!(Some("yugabyte"), config.get_user());
        assert_eq!(Some(&b"pa@ss"[..]), config.get_password());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );
        // A host without an explicit port gets YugabyteDB's default 5433.
        assert_eq!([5434, 5433], config.get_ports());
        assert_eq!(Some("yugabyte"), config.get_dbname());
        assert_eq!("true", config.get_load_balance());
    }
}