xitca_postgres/
config.rs

//! Connection configuration, copied from `tokio-postgres`.
2
3use core::{fmt, iter, mem, str};
4
5use std::{
6    borrow::Cow,
7    path::{Path, PathBuf},
8};
9
10use super::{error::Error, session::TargetSessionAttrs};
11
/// TLS usage policy for a connection.
///
/// Selected with [`Config::ssl_mode`] or through the `sslmode` connection parameter, which accepts
/// `disable`, `prefer` and `require`.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// Do not use TLS.
    Disable,
    /// Attempt to connect with TLS but allow sessions without. This is the default.
    #[default]
    Prefer,
    /// Require the use of TLS.
    Require,
}
23
/// TLS negotiation configuration.
///
/// Selected with [`Config::ssl_negotiation`] or through the `sslnegotiation` connection parameter,
/// which accepts `postgres` and `direct`.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslNegotiation {
    /// Use the PostgreSQL `SslRequest` message for TLS negotiation. This is the default.
    #[default]
    Postgres,
    /// Start the TLS handshake without negotiation. Only works for PostgreSQL 17+.
    Direct,
}
34
/// A host specification.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Host {
    /// A TCP hostname.
    Tcp(Box<str>),
    /// A hostname for a QUIC connection.
    ///
    /// NOTE(review): meaning inferred from the variant name only; no code visible in this file
    /// constructs this variant, so confirm against its users.
    Quic(Box<str>),
    /// A Unix socket path, as added by `Config::host_path` or by `Config::host` for values that
    /// start with `/`.
    Unix(PathBuf),
}
44
/// Connection configuration.
///
/// A value is either built up with the setter methods on `Config` or parsed from a connection
/// string through the `TryFrom<&str>` / `TryFrom<String>` implementations. `Debug` is implemented
/// by hand so that the password is never printed.
#[derive(Clone, Eq, PartialEq)]
pub struct Config {
    /// User to authenticate with.
    pub(crate) user: Option<Box<str>>,
    /// Password to authenticate with, stored as raw bytes.
    pub(crate) password: Option<Box<[u8]>>,
    /// Name of the database to connect to.
    pub(crate) dbname: Option<Box<str>>,
    /// Command line options used to configure the server.
    pub(crate) options: Option<Box<str>>,
    /// Value of the `application_name` runtime parameter.
    pub(crate) application_name: Option<Box<str>>,
    /// Whether TLS is disabled, preferred or required.
    pub(crate) ssl_mode: SslMode,
    /// How the TLS session is negotiated.
    pub(crate) ssl_negotiation: SslNegotiation,
    /// Hosts in the order they were added.
    pub(crate) host: Vec<Host>,
    /// Ports in the order they were added.
    pub(crate) port: Vec<u16>,
    /// Requirements of the session.
    target_session_attrs: TargetSessionAttrs,
    /// Bytes supplied through `Config::tls_server_end_point`, if any.
    tls_server_end_point: Option<Box<[u8]>>,
}
59
60impl Default for Config {
61    fn default() -> Config {
62        Config::new()
63    }
64}
65
66impl Config {
67    /// Creates a new configuration.
68    pub const fn new() -> Config {
69        Config {
70            user: None,
71            password: None,
72            dbname: None,
73            options: None,
74            application_name: None,
75            ssl_mode: SslMode::Prefer,
76            ssl_negotiation: SslNegotiation::Postgres,
77            host: Vec::new(),
78            port: Vec::new(),
79            target_session_attrs: TargetSessionAttrs::Any,
80            tls_server_end_point: None,
81        }
82    }
83
84    /// Sets the user to authenticate with.
85    ///
86    /// Required.
87    pub fn user(&mut self, user: &str) -> &mut Config {
88        self.user = Some(Box::from(user));
89        self
90    }
91
92    /// Gets the user to authenticate with, if one has been configured with
93    /// the `user` method.
94    pub fn get_user(&self) -> Option<&str> {
95        self.user.as_deref()
96    }
97
98    /// Sets the password to authenticate with.
99    pub fn password<T>(&mut self, password: T) -> &mut Config
100    where
101        T: AsRef<[u8]>,
102    {
103        self.password = Some(Box::from(password.as_ref()));
104        self
105    }
106
107    /// Gets the password to authenticate with, if one has been configured with
108    /// the `password` method.
109    pub fn get_password(&self) -> Option<&[u8]> {
110        self.password.as_deref()
111    }
112
113    /// Sets the name of the database to connect to.
114    ///
115    /// Defaults to the user.
116    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
117        self.dbname = Some(Box::from(dbname));
118        self
119    }
120
121    /// Gets the name of the database to connect to, if one has been configured
122    /// with the `dbname` method.
123    pub fn get_dbname(&self) -> Option<&str> {
124        self.dbname.as_deref()
125    }
126
127    /// Sets command line options used to configure the server.
128    pub fn options(&mut self, options: &str) -> &mut Config {
129        self.options = Some(Box::from(options));
130        self
131    }
132
133    /// Gets the command line options used to configure the server, if the
134    /// options have been set with the `options` method.
135    pub fn get_options(&self) -> Option<&str> {
136        self.options.as_deref()
137    }
138
139    /// Sets the value of the `application_name` runtime parameter.
140    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
141        self.application_name = Some(Box::from(application_name));
142        self
143    }
144
145    /// Gets the value of the `application_name` runtime parameter, if it has
146    /// been set with the `application_name` method.
147    pub fn get_application_name(&self) -> Option<&str> {
148        self.application_name.as_deref()
149    }
150
151    /// Sets the SSL configuration.
152    ///
153    /// Defaults to `prefer`.
154    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
155        self.ssl_mode = ssl_mode;
156        self
157    }
158
159    /// Gets the SSL configuration.
160    pub fn get_ssl_mode(&self) -> SslMode {
161        self.ssl_mode
162    }
163
164    /// Sets the SSL negotiation method.
165    ///
166    /// Defaults to `postgres`.
167    pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
168        self.ssl_negotiation = ssl_negotiation;
169        self
170    }
171
172    /// Gets the SSL negotiation method.
173    pub fn get_ssl_negotiation(&self) -> SslNegotiation {
174        self.ssl_negotiation
175    }
176
177    pub fn host(&mut self, host: &str) -> &mut Config {
178        if host.starts_with('/') {
179            return self.host_path(host);
180        }
181
182        let host = Host::Tcp(Box::from(host));
183
184        self.host.push(host);
185        self
186    }
187
188    /// Adds a Unix socket host to the configuration.
189    ///
190    /// Unlike `host`, this method allows non-UTF8 paths.
191    pub fn host_path<T>(&mut self, host: T) -> &mut Config
192    where
193        T: AsRef<Path>,
194    {
195        self.host.push(Host::Unix(host.as_ref().to_path_buf()));
196        self
197    }
198
199    /// Gets the hosts that have been added to the configuration with `host`.
200    pub fn get_hosts(&self) -> &[Host] {
201        &self.host
202    }
203
204    /// Adds a port to the configuration.
205    ///
206    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
207    /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
208    /// as hosts.
209    pub fn port(&mut self, port: u16) -> &mut Config {
210        self.port.push(port);
211        self
212    }
213
214    /// Gets the ports that have been added to the configuration with `port`.
215    pub fn get_ports(&self) -> &[u16] {
216        &self.port
217    }
218
219    /// Sets the requirements of the session.
220    ///
221    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
222    /// secondary servers. Defaults to `Any`.
223    pub fn target_session_attrs(&mut self, target_session_attrs: TargetSessionAttrs) -> &mut Config {
224        self.target_session_attrs = target_session_attrs;
225        self
226    }
227
228    /// Gets the requirements of the session.
229    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
230        self.target_session_attrs
231    }
232
233    /// change the remote peer's tls certificates. it's often coupled with [`Postgres::connect_io`] API for manual tls
234    /// session connecting and channel binding authentication.
235    /// # Examples
236    /// ```rust
237    /// use xitca_postgres::{Config, Postgres};
238    ///
239    /// // handle tls connection on your own.
240    /// async fn connect_io() {
241    ///     let mut cfg = Config::try_from("postgres://postgres:postgres@localhost/postgres").unwrap();
242    ///     
243    ///     // an imaginary function where you establish a tls connection to database on your own.
244    ///     // the established connection should be providing valid cert bytes.
245    ///     let (io, certs) = your_tls_connector().await;
246    ///
247    ///     // set cert bytes to configuration
248    ///     cfg.tls_server_end_point(certs);
249    ///
250    ///     // give xitca-postgres the config and established io and finish db session process.
251    ///     let _ = Postgres::new(cfg).connect_io(io).await;
252    /// }
253    ///
254    /// async fn your_tls_connector() -> (MyTlsStream, Vec<u8>) {
255    ///     todo!("your tls connecting logic lives here. the process can be async or not.")
256    /// }
257    ///
258    /// // a possible type representation of your manual tls connection to database
259    /// struct MyTlsStream;
260    ///
261    /// # use std::{io, pin::Pin, task::{Context, Poll}};
262    /// #
263    /// # use xitca_io::io::{AsyncIo, Interest, Ready};
264    /// #   
265    /// # impl AsyncIo for MyTlsStream {
266    /// #   async fn ready(&mut self, interest: Interest) -> io::Result<Ready> {
267    /// #       todo!()
268    /// #   }
269    /// #
270    /// #   fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
271    /// #       todo!()
272    /// #   }
273    /// #   
274    /// #   fn is_vectored_write(&self) -> bool {
275    /// #       false
276    /// #   }
277    /// #   
278    /// #   fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
279    /// #       Poll::Ready(Ok(()))
280    /// #   }
281    /// # }
282    /// #   
283    /// # impl io::Read for MyTlsStream {
284    /// #   fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
285    /// #       todo!()
286    /// #   }
287    /// # }   
288    /// #
289    /// # impl io::Write for MyTlsStream {
290    /// #   fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
291    /// #       todo!()
292    /// #   }
293    /// #   
294    /// #   fn flush(&mut self) -> io::Result<()> {
295    /// #       Ok(())
296    /// #   }
297    /// # }
298    /// ```
299    ///
300    /// [`Postgres::connect_io`]: crate::Postgres::connect_io
301    pub fn tls_server_end_point(&mut self, tls_server_end_point: impl AsRef<[u8]>) -> &mut Self {
302        self.tls_server_end_point = Some(Box::from(tls_server_end_point.as_ref()));
303        self
304    }
305
306    pub fn get_tls_server_end_point(&self) -> Option<&[u8]> {
307        self.tls_server_end_point.as_deref()
308    }
309
310    fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
311        match key {
312            "user" => {
313                self.user(value);
314            }
315            "password" => {
316                self.password(value);
317            }
318            "dbname" => {
319                self.dbname(value);
320            }
321            "options" => {
322                self.options(value);
323            }
324            "application_name" => {
325                self.application_name(value);
326            }
327            "sslmode" => {
328                let mode = match value {
329                    "disable" => SslMode::Disable,
330                    "prefer" => SslMode::Prefer,
331                    "require" => SslMode::Require,
332                    _ => return Err(Error::todo()),
333                };
334                self.ssl_mode(mode);
335            }
336            "sslnegotiation" => {
337                let mode = match value {
338                    "postgres" => SslNegotiation::Postgres,
339                    "direct" => SslNegotiation::Direct,
340                    _ => return Err(Error::todo()),
341                };
342                self.ssl_negotiation(mode);
343            }
344            "host" => {
345                for host in value.split(',') {
346                    self.host(host);
347                }
348            }
349            "port" => {
350                for port in value.split(',') {
351                    let port = if port.is_empty() {
352                        5432
353                    } else {
354                        port.parse().map_err(|_| Error::todo())?
355                    };
356                    self.port(port);
357                }
358            }
359            "target_session_attrs" => {
360                let target_session_attrs = match value {
361                    "any" => TargetSessionAttrs::Any,
362                    "read-write" => TargetSessionAttrs::ReadWrite,
363                    "read-only" => TargetSessionAttrs::ReadOnly,
364                    _ => return Err(Error::todo()),
365                };
366                self.target_session_attrs(target_session_attrs);
367            }
368            _ => {
369                return Err(Error::todo());
370            }
371        }
372
373        Ok(())
374    }
375}
376
377impl TryFrom<String> for Config {
378    type Error = Error;
379
380    fn try_from(s: String) -> Result<Self, Self::Error> {
381        Self::try_from(s.as_str())
382    }
383}
384
385impl TryFrom<&str> for Config {
386    type Error = Error;
387
388    fn try_from(s: &str) -> Result<Self, Self::Error> {
389        match UrlParser::parse(s)? {
390            Some(config) => Ok(config),
391            None => Parser::parse(s),
392        }
393    }
394}
395
396// Omit password from debug output
397impl fmt::Debug for Config {
398    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399        struct Redaction {}
400        impl fmt::Debug for Redaction {
401            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402                write!(f, "_")
403            }
404        }
405
406        f.debug_struct("Config")
407            .field("user", &self.user)
408            .field("password", &self.password.as_ref().map(|_| Redaction {}))
409            .field("dbname", &self.dbname)
410            .field("options", &self.options)
411            .field("application_name", &self.application_name)
412            .field("host", &self.host)
413            .field("port", &self.port)
414            .field("target_session_attrs", &self.target_session_attrs)
415            .finish()
416    }
417}
418
419struct Parser<'a> {
420    s: &'a str,
421    it: iter::Peekable<str::CharIndices<'a>>,
422}
423
424impl<'a> Parser<'a> {
425    fn parse(s: &'a str) -> Result<Config, Error> {
426        let mut parser = Parser {
427            s,
428            it: s.char_indices().peekable(),
429        };
430
431        let mut config = Config::new();
432
433        while let Some((key, value)) = parser.parameter()? {
434            config.param(key, &value)?;
435        }
436
437        Ok(config)
438    }
439
440    fn skip_ws(&mut self) {
441        self.take_while(char::is_whitespace);
442    }
443
444    fn take_while<F>(&mut self, f: F) -> &'a str
445    where
446        F: Fn(char) -> bool,
447    {
448        let start = match self.it.peek() {
449            Some(&(i, _)) => i,
450            None => return "",
451        };
452
453        loop {
454            match self.it.peek() {
455                Some(&(_, c)) if f(c) => {
456                    self.it.next();
457                }
458                Some(&(i, _)) => return &self.s[start..i],
459                None => return &self.s[start..],
460            }
461        }
462    }
463
464    fn eat(&mut self, target: char) -> Result<(), Error> {
465        match self.it.next() {
466            Some((_, c)) if c == target => Ok(()),
467            Some((i, c)) => {
468                let _m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
469                Err(Error::todo())
470            }
471            None => Err(Error::todo()),
472        }
473    }
474
475    fn eat_if(&mut self, target: char) -> bool {
476        match self.it.peek() {
477            Some(&(_, c)) if c == target => {
478                self.it.next();
479                true
480            }
481            _ => false,
482        }
483    }
484
485    fn keyword(&mut self) -> Option<&'a str> {
486        let s = self.take_while(|c| match c {
487            c if c.is_whitespace() => false,
488            '=' => false,
489            _ => true,
490        });
491
492        if s.is_empty() {
493            None
494        } else {
495            Some(s)
496        }
497    }
498
499    fn value(&mut self) -> Result<String, Error> {
500        let value = if self.eat_if('\'') {
501            let value = self.quoted_value()?;
502            self.eat('\'')?;
503            value
504        } else {
505            self.simple_value()?
506        };
507
508        Ok(value)
509    }
510
511    fn simple_value(&mut self) -> Result<String, Error> {
512        let mut value = String::new();
513
514        while let Some(&(_, c)) = self.it.peek() {
515            if c.is_whitespace() {
516                break;
517            }
518
519            self.it.next();
520            if c == '\\' {
521                if let Some((_, c2)) = self.it.next() {
522                    value.push(c2);
523                }
524            } else {
525                value.push(c);
526            }
527        }
528
529        if value.is_empty() {
530            return Err(Error::todo());
531        }
532
533        Ok(value)
534    }
535
536    fn quoted_value(&mut self) -> Result<String, Error> {
537        let mut value = String::new();
538
539        while let Some(&(_, c)) = self.it.peek() {
540            if c == '\'' {
541                return Ok(value);
542            }
543
544            self.it.next();
545            if c == '\\' {
546                if let Some((_, c2)) = self.it.next() {
547                    value.push(c2);
548                }
549            } else {
550                value.push(c);
551            }
552        }
553
554        Err(Error::todo())
555    }
556
557    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
558        self.skip_ws();
559        let keyword = match self.keyword() {
560            Some(keyword) => keyword,
561            None => return Ok(None),
562        };
563        self.skip_ws();
564        self.eat('=')?;
565        self.skip_ws();
566        let value = self.value()?;
567
568        Ok(Some((keyword, value)))
569    }
570}
571
572// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
573struct UrlParser<'a> {
574    s: &'a str,
575    config: Config,
576}
577
578impl<'a> UrlParser<'a> {
579    fn parse(s: &'a str) -> Result<Option<Config>, Error> {
580        let s = match Self::remove_url_prefix(s) {
581            Some(s) => s,
582            None => return Ok(None),
583        };
584
585        let mut parser = UrlParser {
586            s,
587            config: Config::new(),
588        };
589
590        parser.parse_credentials()?;
591        parser.parse_host()?;
592        parser.parse_path()?;
593        parser.parse_params()?;
594
595        Ok(Some(parser.config))
596    }
597
598    fn remove_url_prefix(s: &str) -> Option<&str> {
599        for prefix in &["postgres://", "postgresql://"] {
600            if let Some(stripped) = s.strip_prefix(prefix) {
601                return Some(stripped);
602            }
603        }
604
605        None
606    }
607
608    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
609        match self.s.find(end) {
610            Some(pos) => {
611                let (head, tail) = self.s.split_at(pos);
612                self.s = tail;
613                Some(head)
614            }
615            None => None,
616        }
617    }
618
619    fn take_all(&mut self) -> &'a str {
620        mem::take(&mut self.s)
621    }
622
623    fn eat_byte(&mut self) {
624        self.s = &self.s[1..];
625    }
626
627    fn parse_credentials(&mut self) -> Result<(), Error> {
628        let creds = match self.take_until(&['@']) {
629            Some(creds) => creds,
630            None => return Ok(()),
631        };
632        self.eat_byte();
633
634        let mut it = creds.splitn(2, ':');
635        let user = self.decode(it.next().unwrap())?;
636        self.config.user(&user);
637
638        if let Some(password) = it.next() {
639            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
640            self.config.password(password);
641        }
642
643        Ok(())
644    }
645
646    fn parse_host(&mut self) -> Result<(), Error> {
647        let host = match self.take_until(&['/', '?']) {
648            Some(host) => host,
649            None => self.take_all(),
650        };
651
652        if host.is_empty() {
653            return Ok(());
654        }
655
656        for chunk in host.split(',') {
657            let (host, port) = if chunk.starts_with('[') {
658                let idx = match chunk.find(']') {
659                    Some(idx) => idx,
660                    None => return Err(Error::todo()),
661                };
662
663                let host = &chunk[1..idx];
664                let remaining = &chunk[idx + 1..];
665                let port = if let Some(port) = remaining.strip_prefix(':') {
666                    Some(port)
667                } else if remaining.is_empty() {
668                    None
669                } else {
670                    return Err(Error::todo());
671                };
672
673                (host, port)
674            } else {
675                let mut it = chunk.splitn(2, ':');
676                (it.next().unwrap(), it.next())
677            };
678
679            self.host_param(host)?;
680            let port = self.decode(port.unwrap_or("5432"))?;
681            self.config.param("port", &port)?;
682        }
683
684        Ok(())
685    }
686
687    fn parse_path(&mut self) -> Result<(), Error> {
688        if !self.s.starts_with('/') {
689            return Ok(());
690        }
691        self.eat_byte();
692
693        let dbname = match self.take_until(&['?']) {
694            Some(dbname) => dbname,
695            None => self.take_all(),
696        };
697
698        if !dbname.is_empty() {
699            self.config.dbname(&self.decode(dbname)?);
700        }
701
702        Ok(())
703    }
704
705    fn parse_params(&mut self) -> Result<(), Error> {
706        if !self.s.starts_with('?') {
707            return Ok(());
708        }
709        self.eat_byte();
710
711        while !self.s.is_empty() {
712            let key = match self.take_until(&['=']) {
713                Some(key) => self.decode(key)?,
714                None => return Err(Error::todo()),
715            };
716            self.eat_byte();
717
718            let value = match self.take_until(&['&']) {
719                Some(value) => {
720                    self.eat_byte();
721                    value
722                }
723                None => self.take_all(),
724            };
725
726            if key == "host" {
727                self.host_param(value)?;
728            } else {
729                let value = self.decode(value)?;
730                self.config.param(&key, &value)?;
731            }
732        }
733
734        Ok(())
735    }
736
737    fn host_param(&mut self, s: &str) -> Result<(), Error> {
738        let s = self.decode(s)?;
739        self.config.param("host", &s)
740    }
741
742    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
743        percent_encoding::percent_decode(s.as_bytes())
744            .decode_utf8()
745            .map_err(|_| Error::todo())
746    }
747}