yb_tokio_postgres/
config.rs

1//! Connection configuration.
2
3#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect::yb_connect;
6use crate::connect_raw::connect_raw;
7#[cfg(not(target_arch = "wasm32"))]
8use crate::keepalive::KeepaliveConfig;
9#[cfg(feature = "runtime")]
10use crate::tls::MakeTlsConnect;
11use crate::tls::TlsConnect;
12#[cfg(feature = "runtime")]
13use crate::Socket;
14use crate::{Client, Connection, Error};
15use std::borrow::Cow;
16use std::collections::HashMap;
17#[cfg(unix)]
18use std::ffi::OsStr;
19use std::net::IpAddr;
20use std::ops::Deref;
21#[cfg(unix)]
22use std::os::unix::ffi::OsStrExt;
23#[cfg(unix)]
24use std::path::{Path, PathBuf};
25use std::str;
26use std::str::FromStr;
27use std::time::Duration;
28use std::{error, fmt, iter, mem};
29use tokio::io::{AsyncRead, AsyncWrite};
30
/// Requirements a session has to satisfy once it is established.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Any session is acceptable; no additional property is checked.
    Any,
    /// Only a session that accepts writes is acceptable.
    ReadWrite,
}
40
/// How TLS is negotiated for a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum SslMode {
    /// Never use TLS.
    Disable,
    /// Try TLS first, but accept an unencrypted session if TLS is unavailable.
    Prefer,
    /// Insist on TLS.
    Require,
}
52
/// Whether channel binding is used during authentication.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Never use channel binding.
    Disable,
    /// Use channel binding when possible, but allow sessions without it.
    Prefer,
    /// Fail authentication unless channel binding is used.
    Require,
}
64
/// Order in which the configured hosts are attempted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
    /// Attempt hosts in exactly the order they were supplied.
    Disable,
    /// Attempt hosts in a randomized order.
    Random,
}
74
/// One place where the server can be reached.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Host {
    /// A hostname reached over TCP.
    Tcp(String),
    /// A directory holding the server's Unix-domain socket.
    ///
    /// Only exists on Unix platforms.
    #[cfg(unix)]
    Unix(PathBuf),
}
86
87/// Connection configuration.
88///
89/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
90///
91/// # Key-Value
92///
93/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
94/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
95///
96/// ## Keys
97///
98/// * `user` - The username to authenticate with. Defaults to the user executing this process.
99/// * `password` - The password to authenticate with.
100/// * `dbname` - The name of the database to connect to. Defaults to the username.
101/// * `options` - Command line options used to configure the server.
102/// * `application_name` - Sets the `application_name` parameter on the server.
103/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
104///     if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
105/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
106///     path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
107///     can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
108///     with the `connect` method.
109/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
110///     e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
111///     If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
112///     or if host specifies an IP address, that value will be used directly.
113///     Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications
114///     with time constraints. However, a host name is required for TLS certificate verification.
115///     Specifically:
116///         * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address.
117///             The connection attempt will fail if the authentication method requires a host name;
118///         * If `host` is specified without `hostaddr`, a host name lookup occurs;
119///         * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address.
120///             The value for `host` is ignored unless the authentication method requires it,
121///             in which case it will be used as the host name.
122/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
123///     either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
124///     omitted or the empty string.
125/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
126///     can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
127/// * `tcp_user_timeout` - The time limit that transmitted data may remain unacknowledged before a connection is forcibly closed.
128///     This is ignored for Unix domain socket connections. It is only supported on systems where TCP_USER_TIMEOUT is available
129///     and will default to the system default if omitted or set to 0; on other systems, it has no effect.
130/// * `keepalives` - Controls the use of TCP keepalive. A value of 0 disables keepalive and nonzero integers enable it.
131///     This option is ignored when connecting with Unix sockets. Defaults to on.
132/// * `keepalives_idle` - The number of seconds of inactivity after which a keepalive message is sent to the server.
133///     This option is ignored when connecting with Unix sockets. Defaults to 2 hours.
134/// * `keepalives_interval` - The time interval between TCP keepalive probes.
135///     This option is ignored when connecting with Unix sockets.
136/// * `keepalives_retries` - The maximum number of TCP keepalive probes that will be sent before dropping a connection.
137///     This option is ignored when connecting with Unix sockets.
138/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
139///     the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
140///     in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
141/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
142///     binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
143///     If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
144/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and
145///     addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter
146///     is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to
147///     `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried
148///     in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults
149///     to `disable`.
150/// * `load_balance` -  Defaults to upstream driver behavior unless set to one of the allowed values (true or any, only-rr, only-primary,
151///     prefer-primary, prefer-rr and false) other than 'false'.
152/// * `topology_keys` - It takes a comma separated geo-location values. A single geo-location can be given as 'cloud.region.zone'.
153///     Multiple geo-locations too can be specified, separated by comma (,). Each placement value can be suffixed with a colon (:)
154///     followed by a preference value between 1 and 10. A preference value of :1 means it is a primary placement. A preference
155///     value of :2 means it is the first fallback placement and so on. If no preference value is provided, it is considered to
156///     be a primary placement (equivalent to one with preference value :1).
157/// * `yb_servers_refresh_interval` - Time interval, in seconds, between two attempts to refresh the information about cluster nodes.
158///     Default is 300. Valid values are integers between 0 and 600. Value 0 means refresh for each connection request. Any value
159///     outside this range is ignored and the default is used.
160/// * `fallback_to_topology_keys_only` - (default value: false) Applicable only for TopologyAware Load Balancing. When set to true,
161///     the smart driver does not attempt to connect to servers outside of primary and fallback placements specified via property.
162///     The default behaviour is to fallback to any available server in the entire cluster.
163/// * `failed_host_reconnect_delay_secs` - (default value: 5 seconds) The driver marks a server as failed with a timestamp, when it cannot
164///     connect to it. Later, whenever it refreshes the server list via yb_servers(), if it sees the failed server in the response,
165///     it marks the server as UP only if failed-host-reconnect-delay-secs time has elapsed. (The yb_servers() function does not remove
166///     a failed server immediately from its result and retains it for a while.)
167///
168/// ## Examples
169///
170/// ```not_rust
171/// host=localhost user=postgres connect_timeout=10 keepalives=0
172/// ```
173///
174/// ```not_rust
175/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
176/// ```
177///
178/// ```not_rust
179/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write
180/// ```
181///
182/// ```not_rust
183/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
184/// ```
185///
186/// # Url
187///
188/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
189/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
190/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
191/// as the path component of the URL specifies the database name.
192///
193/// ## Examples
194///
195/// ```not_rust
196/// postgresql://user@localhost
197/// ```
198///
199/// ```not_rust
200/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
201/// ```
202///
203/// ```not_rust
204/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
205/// ```
206///
207/// ```not_rust
208/// postgresql:///mydb?user=user&host=/var/lib/postgresql
209/// ```
210#[derive(Clone, PartialEq, Eq)]
211pub struct Config {
212    pub(crate) user: Option<String>,
213    pub(crate) password: Option<Vec<u8>>,
214    pub(crate) dbname: Option<String>,
215    pub(crate) options: Option<String>,
216    pub(crate) application_name: Option<String>,
217    pub(crate) ssl_mode: SslMode,
218    pub(crate) host: Vec<Host>,
219    pub(crate) hostaddr: Vec<IpAddr>,
220    pub(crate) port: Vec<u16>,
221    pub(crate) connect_timeout: Option<Duration>,
222    pub(crate) tcp_user_timeout: Option<Duration>,
223    pub(crate) keepalives: bool,
224    #[cfg(not(target_arch = "wasm32"))]
225    pub(crate) keepalive_config: KeepaliveConfig,
226    pub(crate) target_session_attrs: TargetSessionAttrs,
227    pub(crate) channel_binding: ChannelBinding,
228    pub(crate) load_balance_hosts: LoadBalanceHosts,
229    /// YugabyteDB Specific
230    pub(crate) load_balance: String,
231    pub(crate) topology_keys: HashMap<i64, Vec<String>>,
232    pub(crate) yb_servers_refresh_interval: Duration,
233    pub(crate) fallback_to_topology_keys_only: bool,
234    pub(crate) failed_host_reconnect_delay_secs: Duration,
235}
236
237impl Default for Config {
238    fn default() -> Config {
239        Config::new()
240    }
241}
242
243impl Config {
244    /// Creates a new configuration.
245    pub fn new() -> Config {
246        Config {
247            user: None,
248            password: None,
249            dbname: None,
250            options: None,
251            application_name: None,
252            ssl_mode: SslMode::Prefer,
253            host: vec![],
254            hostaddr: vec![],
255            port: vec![],
256            connect_timeout: None,
257            tcp_user_timeout: None,
258            keepalives: true,
259            #[cfg(not(target_arch = "wasm32"))]
260            keepalive_config: KeepaliveConfig {
261                idle: Duration::from_secs(2 * 60 * 60),
262                interval: None,
263                retries: None,
264            },
265            target_session_attrs: TargetSessionAttrs::Any,
266            channel_binding: ChannelBinding::Prefer,
267            load_balance_hosts: LoadBalanceHosts::Disable,
268            load_balance: String::from("false"),
269            topology_keys: HashMap::new(),
270            yb_servers_refresh_interval: Duration::new(300, 0),
271            fallback_to_topology_keys_only: false,
272            failed_host_reconnect_delay_secs: Duration::new(5, 0),
273        }
274    }
275
276    /// Sets the user to authenticate with.
277    ///
278    /// Defaults to the user executing this process.
279    pub fn user(&mut self, user: &str) -> &mut Config {
280        self.user = Some(user.to_string());
281        self
282    }
283
284    /// Gets the user to authenticate with, if one has been configured with
285    /// the `user` method.
286    pub fn get_user(&self) -> Option<&str> {
287        self.user.as_deref()
288    }
289
290    /// Sets the password to authenticate with.
291    pub fn password<T>(&mut self, password: T) -> &mut Config
292    where
293        T: AsRef<[u8]>,
294    {
295        self.password = Some(password.as_ref().to_vec());
296        self
297    }
298
299    /// Gets the password to authenticate with, if one has been configured with
300    /// the `password` method.
301    pub fn get_password(&self) -> Option<&[u8]> {
302        self.password.as_deref()
303    }
304
305    /// Sets the name of the database to connect to.
306    ///
307    /// Defaults to the user.
308    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
309        self.dbname = Some(dbname.to_string());
310        self
311    }
312
313    /// Gets the name of the database to connect to, if one has been configured
314    /// with the `dbname` method.
315    pub fn get_dbname(&self) -> Option<&str> {
316        self.dbname.as_deref()
317    }
318
319    /// Sets command line options used to configure the server.
320    pub fn options(&mut self, options: &str) -> &mut Config {
321        self.options = Some(options.to_string());
322        self
323    }
324
325    /// Gets the command line options used to configure the server, if the
326    /// options have been set with the `options` method.
327    pub fn get_options(&self) -> Option<&str> {
328        self.options.as_deref()
329    }
330
331    /// Sets the value of the `application_name` runtime parameter.
332    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
333        self.application_name = Some(application_name.to_string());
334        self
335    }
336
337    /// Gets the value of the `application_name` runtime parameter, if it has
338    /// been set with the `application_name` method.
339    pub fn get_application_name(&self) -> Option<&str> {
340        self.application_name.as_deref()
341    }
342
343    /// Sets the SSL configuration.
344    ///
345    /// Defaults to `prefer`.
346    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
347        self.ssl_mode = ssl_mode;
348        self
349    }
350
351    /// Gets the SSL configuration.
352    pub fn get_ssl_mode(&self) -> SslMode {
353        self.ssl_mode
354    }
355
356    /// Adds a host to the configuration.
357    ///
358    /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
359    /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
360    /// There must be either no hosts, or the same number of hosts as hostaddrs.
361    pub fn host(&mut self, host: &str) -> &mut Config {
362        #[cfg(unix)]
363        {
364            if host.starts_with('/') {
365                return self.host_path(host);
366            }
367        }
368
369        self.host.push(Host::Tcp(host.to_string()));
370        self
371    }
372
373    /// Gets the hosts that have been added to the configuration with `host`.
374    pub fn get_hosts(&self) -> &[Host] {
375        &self.host
376    }
377
378    /// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
379    pub fn get_hostaddrs(&self) -> &[IpAddr] {
380        self.hostaddr.deref()
381    }
382
383    /// Adds a Unix socket host to the configuration.
384    ///
385    /// Unlike `host`, this method allows non-UTF8 paths.
386    #[cfg(unix)]
387    pub fn host_path<T>(&mut self, host: T) -> &mut Config
388    where
389        T: AsRef<Path>,
390    {
391        self.host.push(Host::Unix(host.as_ref().to_path_buf()));
392        self
393    }
394
395    /// Adds a hostaddr to the configuration.
396    ///
397    /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
398    /// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
399    pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
400        self.hostaddr.push(hostaddr);
401        self
402    }
403
404    /// Adds a port to the configuration.
405    ///
406    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
407    /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
408    /// as hosts.
409    pub fn port(&mut self, port: u16) -> &mut Config {
410        self.port.push(port);
411        self
412    }
413
414    /// Gets the ports that have been added to the configuration with `port`.
415    pub fn get_ports(&self) -> &[u16] {
416        &self.port
417    }
418
419    /// Sets the timeout applied to socket-level connection attempts.
420    ///
421    /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
422    /// host separately. Defaults to no limit.
423    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
424        self.connect_timeout = Some(connect_timeout);
425        self
426    }
427
428    /// Gets the connection timeout, if one has been set with the
429    /// `connect_timeout` method.
430    pub fn get_connect_timeout(&self) -> Option<&Duration> {
431        self.connect_timeout.as_ref()
432    }
433
434    /// Sets the TCP user timeout.
435    ///
436    /// This is ignored for Unix domain socket connections. It is only supported on systems where
437    /// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
438    /// on other systems, it has no effect.
439    pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
440        self.tcp_user_timeout = Some(tcp_user_timeout);
441        self
442    }
443
444    /// Gets the TCP user timeout, if one has been set with the
445    /// `user_timeout` method.
446    pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
447        self.tcp_user_timeout.as_ref()
448    }
449
450    /// Controls the use of TCP keepalive.
451    ///
452    /// This is ignored for Unix domain socket connections. Defaults to `true`.
453    pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
454        self.keepalives = keepalives;
455        self
456    }
457
458    /// Reports whether TCP keepalives will be used.
459    pub fn get_keepalives(&self) -> bool {
460        self.keepalives
461    }
462
463    /// Sets the amount of idle time before a keepalive packet is sent on the connection.
464    ///
465    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
466    #[cfg(not(target_arch = "wasm32"))]
467    pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
468        self.keepalive_config.idle = keepalives_idle;
469        self
470    }
471
472    /// Gets the configured amount of idle time before a keepalive packet will
473    /// be sent on the connection.
474    #[cfg(not(target_arch = "wasm32"))]
475    pub fn get_keepalives_idle(&self) -> Duration {
476        self.keepalive_config.idle
477    }
478
479    /// Sets the time interval between TCP keepalive probes.
480    /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field.
481    ///
482    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
483    #[cfg(not(target_arch = "wasm32"))]
484    pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
485        self.keepalive_config.interval = Some(keepalives_interval);
486        self
487    }
488
489    /// Gets the time interval between TCP keepalive probes.
490    #[cfg(not(target_arch = "wasm32"))]
491    pub fn get_keepalives_interval(&self) -> Option<Duration> {
492        self.keepalive_config.interval
493    }
494
495    /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
496    ///
497    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
498    #[cfg(not(target_arch = "wasm32"))]
499    pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
500        self.keepalive_config.retries = Some(keepalives_retries);
501        self
502    }
503
504    /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
505    #[cfg(not(target_arch = "wasm32"))]
506    pub fn get_keepalives_retries(&self) -> Option<u32> {
507        self.keepalive_config.retries
508    }
509
510    /// Sets the requirements of the session.
511    ///
512    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
513    /// secondary servers. Defaults to `Any`.
514    pub fn target_session_attrs(
515        &mut self,
516        target_session_attrs: TargetSessionAttrs,
517    ) -> &mut Config {
518        self.target_session_attrs = target_session_attrs;
519        self
520    }
521
522    /// Gets the requirements of the session.
523    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
524        self.target_session_attrs
525    }
526
527    /// Sets the channel binding behavior.
528    ///
529    /// Defaults to `prefer`.
530    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
531        self.channel_binding = channel_binding;
532        self
533    }
534
535    /// Gets the channel binding behavior.
536    pub fn get_channel_binding(&self) -> ChannelBinding {
537        self.channel_binding
538    }
539
540    /// Sets the host load balancing behavior.
541    ///
542    /// Defaults to `disable`.
543    pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
544        self.load_balance_hosts = load_balance_hosts;
545        self
546    }
547
548    /// Gets the host load balancing behavior.
549    pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
550        self.load_balance_hosts
551    }
552
553    /// YugabyteDB Specific.
554    ///
555    /// Sets the load balance parameter.
556    ///
557    /// Defaults to false.
558    pub fn load_balance(&mut self, load_balance: &str) -> &mut Config {
559        self.load_balance = load_balance.to_lowercase();
560        self
561    }
562
563    /// YugabyteDB Specific.
564    ///
565    /// Gets the load balance value
566    pub fn get_load_balance(&self) -> String {
567        self.load_balance.clone()
568    }
569
570    /// YugabyteDB Specific.
571    ///
572    /// Sets the topology key parameter.
573    ///
574    /// Defaults to Hashmap::new().
575    pub fn topology_keys(&mut self, topology_key: &str, priority: i64) -> &mut Config {
576        let current_zones: Option<&Vec<String>> = self.topology_keys.get(&priority);
577        if current_zones.is_none() {
578            let mut topology_vec: Vec<String> = Vec::new();
579            topology_vec.push(topology_key.to_owned());
580            self.topology_keys.insert(priority, topology_vec);
581        } else {
582            let mut current_zones_vec: Vec<String> = current_zones.unwrap().to_vec();
583            current_zones_vec.push(topology_key.to_owned());
584            self.topology_keys.insert(priority, current_zones_vec);
585        }
586        self
587    }
588
589    /// YugabyteDB Specific.
590    ///
591    /// Gets the host topology keys value.
592    pub fn get_topology_keys(&self) -> HashMap<i64, Vec<String>> {
593        self.topology_keys.clone()
594    }
595
596    /// YugabyteDB Specific.
597    ///
598    /// Sets the yb_servers_refresh_interval parameter.
599    ///
600    /// Defaults to 300 sec.
601    pub fn yb_servers_refresh_interval(
602        &mut self,
603        yb_servers_refresh_interval: Duration,
604    ) -> &mut Config {
605        self.yb_servers_refresh_interval = yb_servers_refresh_interval;
606        self
607    }
608
609    /// YugabyteDB Specific.
610    ///
611    /// Gets the yb_servers_refresh_interval value.
612    pub fn get_yb_servers_refresh_interval(&self) -> Duration {
613        self.yb_servers_refresh_interval
614    }
615
616    /// YugabyteDB Specific.
617    ///
618    /// Sets the fallback_to_topology_keys_only parameter.
619    ///
620    /// Defaults to false.
621    pub fn fallback_to_topology_keys_only(
622        &mut self,
623        fallback_to_topology_keys_only: bool,
624    ) -> &mut Config {
625        self.fallback_to_topology_keys_only = fallback_to_topology_keys_only;
626        self
627    }
628
629    /// YugabyteDB Specific.
630    ///
631    /// Gets the fallback_to_topology_keys_only value.
632    pub fn get_fallback_to_topology_keys_only(&self) -> bool {
633        self.fallback_to_topology_keys_only
634    }
635
636    /// YugabyteDB Specific.
637    ///
638    /// Sets the failed_host_reconnect_delay_secs parameter.
639    ///
640    /// Defaults to 5 sec.
641    pub fn failed_host_reconnect_delay_secs(
642        &mut self,
643        failed_host_reconnect_delay_secs: Duration,
644    ) -> &mut Config {
645        self.failed_host_reconnect_delay_secs = failed_host_reconnect_delay_secs;
646        self
647    }
648
649    /// YugabyteDB Specific.
650    ///
651    /// Gets the failed_host_reconnect_delay_secs value.
652    pub fn get_failed_host_reconnect_delay_secs(&self) -> Duration {
653        self.failed_host_reconnect_delay_secs
654    }
655
656    ///Check if the given load_balnce value if one of the allowed values.
657    pub fn is_lb_valid(&self, lb: &str) -> bool {
658        match lb.to_lowercase().as_str(){
659            "only-rr" => true,
660            "only-primary"=> true,
661            "prefer-primary"=> true,
662            "prefer-rr"=> true,
663            "any"=> true,
664            "true"=> true,
665            "false"=> true,
666            _=>false,
667        }
668    }
669
670    ///Check if a given zone in Topology keys is valid
671    pub fn is_valid(&self, zone: &str) -> bool {
672        let mut zones: Vec<&str> = zone.split(":").collect();
673        if zones.is_empty() || zones.len() > 2 {
674            return false;
675        }
676        let placement: Vec<&str> = zones[0].split(".").collect();
677        if placement.len() != 3 {
678            return false;
679        }
680        if zones.len() == 1 {
681            zones.push("1");
682        }
683        let priority = zones[1].parse::<i64>();
684        if priority.is_err() {
685            return false;
686        } else {
687            let priorityvalue = priority.unwrap();
688            if !(1..=10).contains(&priorityvalue) {
689                return false;
690            }
691        }
692        true
693    }
694
695    fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
696        match key {
697            "user" => {
698                self.user(value);
699            }
700            "password" => {
701                self.password(value);
702            }
703            "dbname" => {
704                self.dbname(value);
705            }
706            "options" => {
707                self.options(value);
708            }
709            "application_name" => {
710                self.application_name(value);
711            }
712            "sslmode" => {
713                let mode = match value {
714                    "disable" => SslMode::Disable,
715                    "prefer" => SslMode::Prefer,
716                    "require" => SslMode::Require,
717                    _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
718                };
719                self.ssl_mode(mode);
720            }
721            "host" => {
722                for host in value.split(',') {
723                    self.host(host);
724                }
725            }
726            "hostaddr" => {
727                for hostaddr in value.split(',') {
728                    let addr = hostaddr
729                        .parse()
730                        .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
731                    self.hostaddr(addr);
732                }
733            }
734            "port" => {
735                for port in value.split(',') {
736                    let port = if port.is_empty() {
737                        5433
738                    } else {
739                        port.parse()
740                            .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
741                    };
742                    self.port(port);
743                }
744            }
745            "connect_timeout" => {
746                let timeout = value
747                    .parse::<i64>()
748                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
749                if timeout > 0 {
750                    self.connect_timeout(Duration::from_secs(timeout as u64));
751                }
752            }
753            "tcp_user_timeout" => {
754                let timeout = value
755                    .parse::<i64>()
756                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
757                if timeout > 0 {
758                    self.tcp_user_timeout(Duration::from_secs(timeout as u64));
759                }
760            }
761            #[cfg(not(target_arch = "wasm32"))]
762            "keepalives" => {
763                let keepalives = value
764                    .parse::<u64>()
765                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
766                self.keepalives(keepalives != 0);
767            }
768            #[cfg(not(target_arch = "wasm32"))]
769            "keepalives_idle" => {
770                let keepalives_idle = value
771                    .parse::<i64>()
772                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
773                if keepalives_idle > 0 {
774                    self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
775                }
776            }
777            #[cfg(not(target_arch = "wasm32"))]
778            "keepalives_interval" => {
779                let keepalives_interval = value.parse::<i64>().map_err(|_| {
780                    Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
781                })?;
782                if keepalives_interval > 0 {
783                    self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
784                }
785            }
786            #[cfg(not(target_arch = "wasm32"))]
787            "keepalives_retries" => {
788                let keepalives_retries = value.parse::<u32>().map_err(|_| {
789                    Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
790                })?;
791                self.keepalives_retries(keepalives_retries);
792            }
793            "target_session_attrs" => {
794                let target_session_attrs = match value {
795                    "any" => TargetSessionAttrs::Any,
796                    "read-write" => TargetSessionAttrs::ReadWrite,
797                    _ => {
798                        return Err(Error::config_parse(Box::new(InvalidValue(
799                            "target_session_attrs",
800                        ))));
801                    }
802                };
803                self.target_session_attrs(target_session_attrs);
804            }
805            "channel_binding" => {
806                let channel_binding = match value {
807                    "disable" => ChannelBinding::Disable,
808                    "prefer" => ChannelBinding::Prefer,
809                    "require" => ChannelBinding::Require,
810                    _ => {
811                        return Err(Error::config_parse(Box::new(InvalidValue(
812                            "channel_binding",
813                        ))))
814                    }
815                };
816                self.channel_binding(channel_binding);
817            }
818            "load_balance_hosts" => {
819                let load_balance_hosts = match value {
820                    "disable" => LoadBalanceHosts::Disable,
821                    "random" => LoadBalanceHosts::Random,
822                    _ => {
823                        return Err(Error::config_parse(Box::new(InvalidValue(
824                            "load_balance_hosts",
825                        ))))
826                    }
827                };
828                self.load_balance_hosts(load_balance_hosts);
829            }
830            "load_balance" => {
831                if self.is_lb_valid(value) {
832                    self.load_balance(value);
833                } else {
834                    return Err(Error::config_parse(Box::new(InvalidValue("load_balance"))));
835                }
836            }
837            "topology_keys" => {
838                for topology_keys in value.split(',') {
839                    if self.is_valid(topology_keys) {
840                        let mut zones: Vec<&str> = topology_keys.split(":").collect();
841                        if zones.len() == 1 {
842                            zones.push("1");
843                        }
844                        let priority = zones[1].parse::<i64>().unwrap();
845                        self.topology_keys(zones[0], priority);
846                    } else {
847                        return Err(Error::config_parse(Box::new(InvalidValue("topology_keys"))));
848                    }
849                }
850            }
851            "yb_servers_refresh_interval" => {
852                let refresh_interval = value.parse::<i64>().map_err(|_| {
853                    Error::config_parse(Box::new(InvalidValue("yb_servers_refresh_interval")))
854                })?;
855                if (0..=600).contains(&refresh_interval) {
856                    self.yb_servers_refresh_interval(Duration::from_secs(refresh_interval as u64));
857                }
858            }
859            "fallback_to_topology_keys_only" => {
860                let fallback_to_topology_keys_only = value.parse::<bool>().map_err(|_| {
861                    Error::config_parse(Box::new(InvalidValue("fallback_to_topology_keys_only")))
862                })?;
863                self.fallback_to_topology_keys_only(fallback_to_topology_keys_only);
864            }
865            "failed_host_reconnect_delay_secs" => {
866                let failed_host_reconnect_delay_secs = value.parse::<i64>().map_err(|_| {
867                    Error::config_parse(Box::new(InvalidValue("failed_host_reconnect_delay_secs")))
868                })?;
869                if (0..=60).contains(&failed_host_reconnect_delay_secs) {
870                    self.failed_host_reconnect_delay_secs(Duration::from_secs(
871                        failed_host_reconnect_delay_secs as u64,
872                    ));
873                }
874            }
875            key => {
876                return Err(Error::config_parse(Box::new(UnknownOption(
877                    key.to_string(),
878                ))));
879            }
880        }
881
882        Ok(())
883    }
884
885    /// Opens a connection to a PostgreSQL database.
886    ///
887    /// Requires the `runtime` Cargo feature (enabled by default).
888    #[cfg(feature = "runtime")]
889    pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
890    where
891        T: MakeTlsConnect<Socket>,
892    {
893        if self.load_balance != "false" {
894            yb_connect(tls, self).await
895        } else {
896            connect(tls, self).await
897        }
898    }
899
900    /// Connects to a PostgreSQL database over an arbitrary stream.
901    ///
902    /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
903    pub async fn connect_raw<S, T>(
904        &self,
905        stream: S,
906        tls: T,
907    ) -> Result<(Client, Connection<S, T::Stream>), Error>
908    where
909        S: AsyncRead + AsyncWrite + Unpin,
910        T: TlsConnect<S>,
911    {
912        connect_raw(stream, tls, true, self).await
913    }
914}
915
916impl FromStr for Config {
917    type Err = Error;
918
919    fn from_str(s: &str) -> Result<Config, Error> {
920        match UrlParser::parse(s)? {
921            Some(config) => Ok(config),
922            None => Parser::parse(s),
923        }
924    }
925}
926
927// Omit password from debug output
928impl fmt::Debug for Config {
929    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
930        struct Redaction {}
931        impl fmt::Debug for Redaction {
932            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
933                write!(f, "_")
934            }
935        }
936
937        let mut config_dbg = &mut f.debug_struct("Config");
938        config_dbg = config_dbg
939            .field("user", &self.user)
940            .field("password", &self.password.as_ref().map(|_| Redaction {}))
941            .field("dbname", &self.dbname)
942            .field("options", &self.options)
943            .field("application_name", &self.application_name)
944            .field("ssl_mode", &self.ssl_mode)
945            .field("host", &self.host)
946            .field("hostaddr", &self.hostaddr)
947            .field("port", &self.port)
948            .field("connect_timeout", &self.connect_timeout)
949            .field("tcp_user_timeout", &self.tcp_user_timeout)
950            .field("keepalives", &self.keepalives);
951
952        #[cfg(not(target_arch = "wasm32"))]
953        {
954            config_dbg = config_dbg
955                .field("keepalives_idle", &self.keepalive_config.idle)
956                .field("keepalives_interval", &self.keepalive_config.interval)
957                .field("keepalives_retries", &self.keepalive_config.retries);
958        }
959
960        config_dbg
961            .field("target_session_attrs", &self.target_session_attrs)
962            .field("channel_binding", &self.channel_binding)
963            .finish()
964    }
965}
966
/// Error reported when a connection string names a parameter that
/// `Config::param` does not recognize; holds the offending parameter name.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = &self.0;
        write!(fmt, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
977
/// Error reported when a recognized connection parameter has a value that
/// could not be parsed or is not one of the accepted choices; holds the
/// parameter name.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(fmt, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
988
989struct Parser<'a> {
990    s: &'a str,
991    it: iter::Peekable<str::CharIndices<'a>>,
992}
993
994impl<'a> Parser<'a> {
995    fn parse(s: &'a str) -> Result<Config, Error> {
996        let mut parser = Parser {
997            s,
998            it: s.char_indices().peekable(),
999        };
1000
1001        let mut config = Config::new();
1002
1003        while let Some((key, value)) = parser.parameter()? {
1004            config.param(key, &value)?;
1005        }
1006
1007        Ok(config)
1008    }
1009
1010    fn skip_ws(&mut self) {
1011        self.take_while(char::is_whitespace);
1012    }
1013
1014    fn take_while<F>(&mut self, f: F) -> &'a str
1015    where
1016        F: Fn(char) -> bool,
1017    {
1018        let start = match self.it.peek() {
1019            Some(&(i, _)) => i,
1020            None => return "",
1021        };
1022
1023        loop {
1024            match self.it.peek() {
1025                Some(&(_, c)) if f(c) => {
1026                    self.it.next();
1027                }
1028                Some(&(i, _)) => return &self.s[start..i],
1029                None => return &self.s[start..],
1030            }
1031        }
1032    }
1033
1034    fn eat(&mut self, target: char) -> Result<(), Error> {
1035        match self.it.next() {
1036            Some((_, c)) if c == target => Ok(()),
1037            Some((i, c)) => {
1038                let m = format!(
1039                    "unexpected character at byte {}: expected `{}` but got `{}`",
1040                    i, target, c
1041                );
1042                Err(Error::config_parse(m.into()))
1043            }
1044            None => Err(Error::config_parse("unexpected EOF".into())),
1045        }
1046    }
1047
1048    fn eat_if(&mut self, target: char) -> bool {
1049        match self.it.peek() {
1050            Some(&(_, c)) if c == target => {
1051                self.it.next();
1052                true
1053            }
1054            _ => false,
1055        }
1056    }
1057
1058    fn keyword(&mut self) -> Option<&'a str> {
1059        let s = self.take_while(|c| match c {
1060            c if c.is_whitespace() => false,
1061            '=' => false,
1062            _ => true,
1063        });
1064
1065        if s.is_empty() {
1066            None
1067        } else {
1068            Some(s)
1069        }
1070    }
1071
1072    fn value(&mut self) -> Result<String, Error> {
1073        let value = if self.eat_if('\'') {
1074            let value = self.quoted_value()?;
1075            self.eat('\'')?;
1076            value
1077        } else {
1078            self.simple_value()?
1079        };
1080
1081        Ok(value)
1082    }
1083
1084    fn simple_value(&mut self) -> Result<String, Error> {
1085        let mut value = String::new();
1086
1087        while let Some(&(_, c)) = self.it.peek() {
1088            if c.is_whitespace() {
1089                break;
1090            }
1091
1092            self.it.next();
1093            if c == '\\' {
1094                if let Some((_, c2)) = self.it.next() {
1095                    value.push(c2);
1096                }
1097            } else {
1098                value.push(c);
1099            }
1100        }
1101
1102        if value.is_empty() {
1103            return Err(Error::config_parse("unexpected EOF".into()));
1104        }
1105
1106        Ok(value)
1107    }
1108
1109    fn quoted_value(&mut self) -> Result<String, Error> {
1110        let mut value = String::new();
1111
1112        while let Some(&(_, c)) = self.it.peek() {
1113            if c == '\'' {
1114                return Ok(value);
1115            }
1116
1117            self.it.next();
1118            if c == '\\' {
1119                if let Some((_, c2)) = self.it.next() {
1120                    value.push(c2);
1121                }
1122            } else {
1123                value.push(c);
1124            }
1125        }
1126
1127        Err(Error::config_parse(
1128            "unterminated quoted connection parameter value".into(),
1129        ))
1130    }
1131
1132    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
1133        self.skip_ws();
1134        let keyword = match self.keyword() {
1135            Some(keyword) => keyword,
1136            None => return Ok(None),
1137        };
1138        self.skip_ws();
1139        self.eat('=')?;
1140        self.skip_ws();
1141        let value = self.value()?;
1142
1143        Ok(Some((keyword, value)))
1144    }
1145}
1146
1147// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
1148struct UrlParser<'a> {
1149    s: &'a str,
1150    config: Config,
1151}
1152
1153impl<'a> UrlParser<'a> {
1154    fn parse(s: &'a str) -> Result<Option<Config>, Error> {
1155        let s = match Self::remove_url_prefix(s) {
1156            Some(s) => s,
1157            None => return Ok(None),
1158        };
1159
1160        let mut parser = UrlParser {
1161            s,
1162            config: Config::new(),
1163        };
1164
1165        parser.parse_credentials()?;
1166        parser.parse_host()?;
1167        parser.parse_path()?;
1168        parser.parse_params()?;
1169
1170        Ok(Some(parser.config))
1171    }
1172
1173    fn remove_url_prefix(s: &str) -> Option<&str> {
1174        for prefix in &["postgres://", "postgresql://"] {
1175            if let Some(stripped) = s.strip_prefix(prefix) {
1176                return Some(stripped);
1177            }
1178        }
1179
1180        None
1181    }
1182
1183    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
1184        match self.s.find(end) {
1185            Some(pos) => {
1186                let (head, tail) = self.s.split_at(pos);
1187                self.s = tail;
1188                Some(head)
1189            }
1190            None => None,
1191        }
1192    }
1193
1194    fn take_all(&mut self) -> &'a str {
1195        mem::take(&mut self.s)
1196    }
1197
1198    fn eat_byte(&mut self) {
1199        self.s = &self.s[1..];
1200    }
1201
1202    fn parse_credentials(&mut self) -> Result<(), Error> {
1203        let creds = match self.take_until(&['@']) {
1204            Some(creds) => creds,
1205            None => return Ok(()),
1206        };
1207        self.eat_byte();
1208
1209        let mut it = creds.splitn(2, ':');
1210        let user = self.decode(it.next().unwrap())?;
1211        self.config.user(&user);
1212
1213        if let Some(password) = it.next() {
1214            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
1215            self.config.password(password);
1216        }
1217
1218        Ok(())
1219    }
1220
1221    fn parse_host(&mut self) -> Result<(), Error> {
1222        let host = match self.take_until(&['/', '?']) {
1223            Some(host) => host,
1224            None => self.take_all(),
1225        };
1226
1227        if host.is_empty() {
1228            return Ok(());
1229        }
1230
1231        for chunk in host.split(',') {
1232            let (host, port) = if chunk.starts_with('[') {
1233                let idx = match chunk.find(']') {
1234                    Some(idx) => idx,
1235                    None => return Err(Error::config_parse(InvalidValue("host").into())),
1236                };
1237
1238                let host = &chunk[1..idx];
1239                let remaining = &chunk[idx + 1..];
1240                let port = if let Some(port) = remaining.strip_prefix(':') {
1241                    Some(port)
1242                } else if remaining.is_empty() {
1243                    None
1244                } else {
1245                    return Err(Error::config_parse(InvalidValue("host").into()));
1246                };
1247
1248                (host, port)
1249            } else {
1250                let mut it = chunk.splitn(2, ':');
1251                (it.next().unwrap(), it.next())
1252            };
1253
1254            self.host_param(host)?;
1255            let port = self.decode(port.unwrap_or("5433"))?;
1256            self.config.param("port", &port)?;
1257        }
1258
1259        Ok(())
1260    }
1261
1262    fn parse_path(&mut self) -> Result<(), Error> {
1263        if !self.s.starts_with('/') {
1264            return Ok(());
1265        }
1266        self.eat_byte();
1267
1268        let dbname = match self.take_until(&['?']) {
1269            Some(dbname) => dbname,
1270            None => self.take_all(),
1271        };
1272
1273        if !dbname.is_empty() {
1274            self.config.dbname(&self.decode(dbname)?);
1275        }
1276
1277        Ok(())
1278    }
1279
1280    fn parse_params(&mut self) -> Result<(), Error> {
1281        if !self.s.starts_with('?') {
1282            return Ok(());
1283        }
1284        self.eat_byte();
1285
1286        while !self.s.is_empty() {
1287            let key = match self.take_until(&['=']) {
1288                Some(key) => self.decode(key)?,
1289                None => return Err(Error::config_parse("unterminated parameter".into())),
1290            };
1291            self.eat_byte();
1292
1293            let value = match self.take_until(&['&']) {
1294                Some(value) => {
1295                    self.eat_byte();
1296                    value
1297                }
1298                None => self.take_all(),
1299            };
1300
1301            if key == "host" {
1302                self.host_param(value)?;
1303            } else {
1304                let value = self.decode(value)?;
1305                self.config.param(&key, &value)?;
1306            }
1307        }
1308
1309        Ok(())
1310    }
1311
1312    #[cfg(unix)]
1313    fn host_param(&mut self, s: &str) -> Result<(), Error> {
1314        let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1315        if decoded.first() == Some(&b'/') {
1316            self.config.host_path(OsStr::from_bytes(&decoded));
1317        } else {
1318            let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1319            self.config.host(decoded);
1320        }
1321
1322        Ok(())
1323    }
1324
1325    #[cfg(not(unix))]
1326    fn host_param(&mut self, s: &str) -> Result<(), Error> {
1327        let s = self.decode(s)?;
1328        self.config.param("host", &s)
1329    }
1330
1331    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1332        percent_encoding::percent_decode(s.as_bytes())
1333            .decode_utf8()
1334            .map_err(|e| Error::config_parse(e.into()))
1335    }
1336}
1337
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use crate::{config::Host, Config};

    /// Key/value strings populate user, dbname, and comma-separated host/hostaddr lists.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );

        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );
    }

    /// A malformed IP address in `hostaddr` is rejected.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }

    /// Keywords that `Config::param` does not know are rejected.
    #[test]
    fn test_unknown_option_rejected() {
        assert!("user=pass_user not_a_real_option=1".parse::<Config>().is_err());
    }

    /// A non-numeric port is rejected.
    #[test]
    fn test_invalid_port_rejected() {
        assert!("host=host1 port=abc".parse::<Config>().is_err());
    }
}