sqlint/connector/
postgres.rs

1mod conversion;
2mod error;
3
4use crate::{
5    ast::{Query, Value},
6    connector::{metrics, queryable::*, ResultSet, Transaction},
7    error::{Error, ErrorKind},
8    visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use futures::{future::FutureExt, lock::Mutex};
12use lru_cache::LruCache;
13use native_tls::{Certificate, Identity, TlsConnector};
14use percent_encoding::percent_decode;
15use postgres_native_tls::MakeTlsConnector;
16use std::{
17    borrow::{Borrow, Cow},
18    fmt::{Debug, Display},
19    fs,
20    future::Future,
21    sync::atomic::{AtomicBool, Ordering},
22    time::Duration,
23};
24use tokio_postgres::{
25    config::{ChannelBinding, SslMode},
26    Client, Config, Statement,
27};
28use url::Url;
29
/// Schema used when the connection URL has no `schema` query parameter.
pub(crate) const DEFAULT_SCHEMA: &str = "public";
31
32/// The underlying postgres driver. Only available with the `expose-drivers`
33/// Cargo feature.
34#[cfg(feature = "expose-drivers")]
35pub use tokio_postgres;
36
37use super::IsolationLevel;
38
/// Newtype whose `Debug` output is always `<HIDDEN>`, so the wrapped value
/// (passwords, certificates, identities) never appears in `Debug` output.
#[derive(Clone)]
struct Hidden<T>(T);

impl<T> Debug for Hidden<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "<HIDDEN>";
        formatter.write_str(MASK)
    }
}
47
48struct PostgresClient(Client);
49
50impl Debug for PostgresClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.write_str("PostgresClient")
53    }
54}
55
/// A connector interface for the PostgreSQL database.
#[derive(Debug)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
pub struct PostgreSql {
    /// The tokio-postgres client, wrapped so `Debug` prints a fixed name only.
    client: PostgresClient,
    /// Whether pgbouncer mode is on; `server_reset_query` issues `DEALLOCATE ALL` when set.
    pg_bouncer: bool,
    /// Timeout passed to the socket-timeout helper for every driver call in `perform_io`.
    socket_timeout: Option<Duration>,
    /// Prepared statements keyed by SQL text, looked up by `fetch_cached`.
    statement_cache: Mutex<LruCache<String, Statement>>,
    /// Starts `true`; set to `false` by `perform_io` when a driver error reports a closed connection.
    is_healthy: AtomicBool,
}
66
/// How strictly the server's TLS certificate is checked (`sslaccept` query parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
pub enum SslAcceptMode {
    /// Certificate validation stays enabled (`sslaccept=strict`; also used for
    /// unrecognised `sslaccept` values).
    Strict,
    /// Invalid certificates are accepted via `danger_accept_invalid_certs`.
    /// This is the default when the URL has no `sslaccept` parameter.
    AcceptInvalidCerts,
}
73
/// TLS settings read from the connection URL's query string. The files are
/// only read when the parameters are turned into an `SslAuth`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
pub struct SslParams {
    /// Path of a PEM root certificate to trust (`sslcert`).
    certificate_file: Option<String>,
    /// Path of a PKCS#12 client identity (`sslidentity`).
    identity_file: Option<String>,
    /// Password of the PKCS#12 identity (`sslpassword`); hidden from `Debug` output.
    identity_password: Hidden<Option<String>>,
    /// Certificate validation mode (`sslaccept`).
    ssl_accept_mode: SslAcceptMode,
}
82
/// TLS material loaded from the files named in `SslParams`, ready to be applied
/// to a `TlsConnector` builder.
#[derive(Debug)]
struct SslAuth {
    /// Parsed root certificate, if one was configured; hidden from `Debug` output.
    certificate: Hidden<Option<Certificate>>,
    /// Parsed client identity, if one was configured; hidden from `Debug` output.
    identity: Hidden<Option<Identity>>,
    /// Certificate validation mode.
    ssl_accept_mode: SslAcceptMode,
}
89
90impl Default for SslAuth {
91    fn default() -> Self {
92        Self { certificate: Hidden(None), identity: Hidden(None), ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts }
93    }
94}
95
96impl SslAuth {
97    fn certificate(&mut self, certificate: Certificate) -> &mut Self {
98        self.certificate = Hidden(Some(certificate));
99        self
100    }
101
102    fn identity(&mut self, identity: Identity) -> &mut Self {
103        self.identity = Hidden(Some(identity));
104        self
105    }
106
107    fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self {
108        self.ssl_accept_mode = mode;
109        self
110    }
111}
112
113impl SslParams {
114    async fn into_auth(self) -> crate::Result<SslAuth> {
115        let mut auth = SslAuth::default();
116        auth.accept_mode(self.ssl_accept_mode);
117
118        if let Some(ref cert_file) = self.certificate_file {
119            let cert = fs::read(cert_file).map_err(|err| {
120                Error::builder(ErrorKind::TlsError { message: format!("cert file not found ({err})") }).build()
121            })?;
122
123            auth.certificate(Certificate::from_pem(&cert)?);
124        }
125
126        if let Some(ref identity_file) = self.identity_file {
127            let db = fs::read(identity_file).map_err(|err| {
128                Error::builder(ErrorKind::TlsError { message: format!("identity file not found ({err})") }).build()
129            })?;
130            let password = self.identity_password.0.as_deref().unwrap_or("");
131            let identity = Identity::from_pkcs12(&db, password)?;
132
133            auth.identity(identity);
134        }
135
136        Ok(auth)
137    }
138}
139
/// Which PostgreSQL-compatible database a `PostgresUrl` points to. It decides
/// whether the search path can be sent as a connection parameter or must be set
/// with a query after connecting.
#[derive(Debug, Clone, Copy)]
pub enum PostgresFlavour {
    /// PostgreSQL: the search path is sent as a (quoted) connection parameter.
    Postgres,
    /// CockroachDB: the search path is sent as a connection parameter only when
    /// the schema name is a safe identifier; otherwise it is set by a query.
    Cockroach,
    /// Not known (the default): the search path is set by a query after connecting.
    Unknown,
}
146
147impl PostgresFlavour {
148    /// Returns `true` if the postgres flavour is [`Postgres`].
149    ///
150    /// [`Postgres`]: PostgresFlavour::Postgres
151    fn is_postgres(&self) -> bool {
152        matches!(self, Self::Postgres)
153    }
154
155    /// Returns `true` if the postgres flavour is [`Cockroach`].
156    ///
157    /// [`Cockroach`]: PostgresFlavour::Cockroach
158    fn is_cockroach(&self) -> bool {
159        matches!(self, Self::Cockroach)
160    }
161
162    /// Returns `true` if the postgres flavour is [`Unknown`].
163    ///
164    /// [`Unknown`]: PostgresFlavour::Unknown
165    fn is_unknown(&self) -> bool {
166        matches!(self, Self::Unknown)
167    }
168}
169
/// Wraps a connection url and exposes the parsing logic used by Sqlint,
/// including default values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
pub struct PostgresUrl {
    /// The connection URL as given.
    url: Url,
    /// Settings parsed from the URL's query string.
    query_params: PostgresUrlQueryParams,
    /// Database flavour; controls how the search path is set. Starts as `Unknown`.
    flavour: PostgresFlavour,
}
179
180impl PostgresUrl {
181    /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection
182    /// parameters.
183    pub fn new(url: Url) -> Result<Self, Error> {
184        let query_params = Self::parse_query_params(&url)?;
185
186        Ok(Self { url, query_params, flavour: PostgresFlavour::Unknown })
187    }
188
189    /// The bare `Url` to the database.
190    pub fn url(&self) -> &Url {
191        &self.url
192    }
193
194    /// The percent-decoded database username.
195    pub fn username(&self) -> Cow<str> {
196        match percent_decode(self.url.username().as_bytes()).decode_utf8() {
197            Ok(username) => username,
198            Err(_) => {
199                tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
200
201                self.url.username().into()
202            }
203        }
204    }
205
206    /// The database host. Taken first from the `host` query parameter, then
207    /// from the `host` part of the URL. For socket connections, the query
208    /// parameter must be used.
209    ///
210    /// If none of them are set, defaults to `localhost`.
211    pub fn host(&self) -> &str {
212        match (self.query_params.host.as_ref(), self.url.host_str()) {
213            (Some(host), _) => host.as_str(),
214            (None, Some("")) => "localhost",
215            (None, None) => "localhost",
216            (None, Some(host)) => host,
217        }
218    }
219
220    /// Name of the database connected. Defaults to `postgres`.
221    pub fn dbname(&self) -> &str {
222        match self.url.path_segments() {
223            Some(mut segments) => segments.next().unwrap_or("postgres"),
224            None => "postgres",
225        }
226    }
227
228    /// The percent-decoded database password.
229    pub fn password(&self) -> Cow<str> {
230        match self.url.password().and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) {
231            Some(password) => password,
232            None => self.url.password().unwrap_or("").into(),
233        }
234    }
235
236    /// The database port, defaults to `5432`.
237    pub fn port(&self) -> u16 {
238        self.url.port().unwrap_or(5432)
239    }
240
241    /// The database schema, defaults to `public`.
242    pub fn schema(&self) -> &str {
243        self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA)
244    }
245
246    /// Whether the pgbouncer mode is enabled.
247    pub fn pg_bouncer(&self) -> bool {
248        self.query_params.pg_bouncer
249    }
250
251    /// The connection timeout.
252    pub fn connect_timeout(&self) -> Option<Duration> {
253        self.query_params.connect_timeout
254    }
255
256    /// Pool check_out timeout
257    pub fn pool_timeout(&self) -> Option<Duration> {
258        self.query_params.pool_timeout
259    }
260
261    /// The socket timeout
262    pub fn socket_timeout(&self) -> Option<Duration> {
263        self.query_params.socket_timeout
264    }
265
266    /// The maximum connection lifetime
267    pub fn max_connection_lifetime(&self) -> Option<Duration> {
268        self.query_params.max_connection_lifetime
269    }
270
271    /// The maximum idle connection lifetime
272    pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
273        self.query_params.max_idle_connection_lifetime
274    }
275
276    /// The custom application name
277    pub fn application_name(&self) -> Option<&str> {
278        self.query_params.application_name.as_deref()
279    }
280
281    pub fn channel_binding(&self) -> ChannelBinding {
282        self.query_params.channel_binding
283    }
284
285    pub(crate) fn cache(&self) -> LruCache<String, Statement> {
286        if self.query_params.pg_bouncer {
287            LruCache::new(0)
288        } else {
289            LruCache::new(self.query_params.statement_cache_size)
290        }
291    }
292
293    pub(crate) fn options(&self) -> Option<&str> {
294        self.query_params.options.as_deref()
295    }
296
297    /// Sets whether the URL points to a Postgres, Cockroach or Unknown database.
298    /// This is used to avoid a network roundtrip at connection to set the search path.
299    ///
300    /// The different behaviours are:
301    /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters.
302    /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query.
303    /// - Unknown: Always add a network roundtrip by setting the search path through a database query.
304    pub fn set_flavour(&mut self, flavour: PostgresFlavour) {
305        self.flavour = flavour;
306    }
307
308    fn parse_query_params(url: &Url) -> Result<PostgresUrlQueryParams, Error> {
309        let mut connection_limit = None;
310        let mut schema = None;
311        let mut certificate_file = None;
312        let mut identity_file = None;
313        let mut identity_password = None;
314        let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts;
315        let mut ssl_mode = SslMode::Prefer;
316        let mut host = None;
317        let mut application_name = None;
318        let mut channel_binding = ChannelBinding::Prefer;
319        let mut socket_timeout = None;
320        let mut connect_timeout = Some(Duration::from_secs(5));
321        let mut pool_timeout = Some(Duration::from_secs(10));
322        let mut pg_bouncer = false;
323        let mut statement_cache_size = 100;
324        let mut max_connection_lifetime = None;
325        let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
326        let mut options = None;
327
328        for (k, v) in url.query_pairs() {
329            match k.as_ref() {
330                "pgbouncer" => {
331                    pg_bouncer =
332                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
333                }
334                "sslmode" => {
335                    match v.as_ref() {
336                        "disable" => ssl_mode = SslMode::Disable,
337                        "prefer" => ssl_mode = SslMode::Prefer,
338                        "require" => ssl_mode = SslMode::Require,
339                        _ => {
340                            tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v);
341                        }
342                    };
343                }
344                "sslcert" => {
345                    certificate_file = Some(v.to_string());
346                }
347                "sslidentity" => {
348                    identity_file = Some(v.to_string());
349                }
350                "sslpassword" => {
351                    identity_password = Some(v.to_string());
352                }
353                "statement_cache_size" => {
354                    statement_cache_size =
355                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
356                }
357                "sslaccept" => {
358                    match v.as_ref() {
359                        "strict" => {
360                            ssl_accept_mode = SslAcceptMode::Strict;
361                        }
362                        "accept_invalid_certs" => {
363                            ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts;
364                        }
365                        _ => {
366                            tracing::debug!(
367                                message = "Unsupported SSL accept mode, defaulting to `strict`",
368                                mode = &*v
369                            );
370
371                            ssl_accept_mode = SslAcceptMode::Strict;
372                        }
373                    };
374                }
375                "schema" => {
376                    schema = Some(v.to_string());
377                }
378                "connection_limit" => {
379                    let as_int: usize =
380                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
381                    connection_limit = Some(as_int);
382                }
383                "host" => {
384                    host = Some(v.to_string());
385                }
386                "socket_timeout" => {
387                    let as_int =
388                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
389                    socket_timeout = Some(Duration::from_secs(as_int));
390                }
391                "connect_timeout" => {
392                    let as_int =
393                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
394
395                    if as_int == 0 {
396                        connect_timeout = None;
397                    } else {
398                        connect_timeout = Some(Duration::from_secs(as_int));
399                    }
400                }
401                "pool_timeout" => {
402                    let as_int =
403                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
404
405                    if as_int == 0 {
406                        pool_timeout = None;
407                    } else {
408                        pool_timeout = Some(Duration::from_secs(as_int));
409                    }
410                }
411                "max_connection_lifetime" => {
412                    let as_int =
413                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
414
415                    if as_int == 0 {
416                        max_connection_lifetime = None;
417                    } else {
418                        max_connection_lifetime = Some(Duration::from_secs(as_int));
419                    }
420                }
421                "max_idle_connection_lifetime" => {
422                    let as_int =
423                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
424
425                    if as_int == 0 {
426                        max_idle_connection_lifetime = None;
427                    } else {
428                        max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
429                    }
430                }
431                "application_name" => {
432                    application_name = Some(v.to_string());
433                }
434                "channel_binding" => {
435                    match v.as_ref() {
436                        "disable" => channel_binding = ChannelBinding::Disable,
437                        "prefer" => channel_binding = ChannelBinding::Prefer,
438                        "require" => channel_binding = ChannelBinding::Require,
439                        _ => {
440                            tracing::debug!(
441                                message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`",
442                                channel_binding = &*v
443                            );
444                        }
445                    };
446                }
447                "options" => {
448                    options = Some(v.to_string());
449                }
450                _ => {
451                    tracing::trace!(message = "Discarding connection string param", param = &*k);
452                }
453            };
454        }
455
456        Ok(PostgresUrlQueryParams {
457            ssl_params: SslParams {
458                certificate_file,
459                identity_file,
460                ssl_accept_mode,
461                identity_password: Hidden(identity_password),
462            },
463            connection_limit,
464            schema,
465            ssl_mode,
466            host,
467            connect_timeout,
468            pool_timeout,
469            socket_timeout,
470            pg_bouncer,
471            statement_cache_size,
472            max_connection_lifetime,
473            max_idle_connection_lifetime,
474            application_name,
475            channel_binding,
476            options,
477        })
478    }
479
480    pub(crate) fn ssl_params(&self) -> &SslParams {
481        &self.query_params.ssl_params
482    }
483
484    #[cfg(feature = "pooled")]
485    pub(crate) fn connection_limit(&self) -> Option<usize> {
486        self.query_params.connection_limit
487    }
488
489    /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection.
490    /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed.
491    /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter.
492    fn set_search_path(&self, config: &mut Config) {
493        // PGBouncer does not support the search_path connection parameter.
494        // https://www.pgbouncer.org/config.html#ignore_startup_parameters
495        if self.query_params.pg_bouncer {
496            return;
497        }
498
499        if let Some(schema) = &self.query_params.schema {
500            if self.flavour().is_cockroach() && is_safe_identifier(schema) {
501                config.search_path(CockroachSearchPath(schema).to_string());
502            }
503
504            if self.flavour().is_postgres() {
505                config.search_path(PostgresSearchPath(schema).to_string());
506            }
507        }
508    }
509
510    pub(crate) fn to_config(&self) -> Config {
511        let mut config = Config::new();
512
513        config.user(self.username().borrow());
514        config.password(self.password().borrow() as &str);
515        config.host(self.host());
516        config.port(self.port());
517        config.dbname(self.dbname());
518        config.pgbouncer_mode(self.query_params.pg_bouncer);
519
520        if let Some(options) = self.options() {
521            config.options(options);
522        }
523
524        if let Some(application_name) = self.application_name() {
525            config.application_name(application_name);
526        }
527
528        if let Some(connect_timeout) = self.query_params.connect_timeout {
529            config.connect_timeout(connect_timeout);
530        }
531
532        self.set_search_path(&mut config);
533
534        config.ssl_mode(self.query_params.ssl_mode);
535
536        config.channel_binding(self.query_params.channel_binding);
537
538        config
539    }
540
541    pub fn flavour(&self) -> PostgresFlavour {
542        self.flavour
543    }
544}
545
/// Settings parsed from a PostgreSQL connection URL's query string. Each field
/// comment names the query parameter the value comes from.
#[derive(Debug, Clone)]
pub(crate) struct PostgresUrlQueryParams {
    /// TLS settings (`sslcert`, `sslidentity`, `sslpassword`, `sslaccept`).
    ssl_params: SslParams,
    /// `connection_limit`; only read when the `pooled` feature is enabled.
    connection_limit: Option<usize>,
    /// `schema`; used to set the search path when present.
    schema: Option<String>,
    /// `sslmode`: `disable`, `prefer` (default) or `require`.
    ssl_mode: SslMode,
    /// `pgbouncer`: enables pgbouncer mode.
    pg_bouncer: bool,
    /// `host`: takes precedence over the host part of the URL.
    host: Option<String>,
    /// `socket_timeout`, in seconds; unset by default.
    socket_timeout: Option<Duration>,
    /// `connect_timeout`, in seconds; 5 seconds by default, `0` disables it.
    connect_timeout: Option<Duration>,
    /// `pool_timeout`, in seconds; 10 seconds by default, `0` disables it.
    pool_timeout: Option<Duration>,
    /// `statement_cache_size`: prepared-statement cache capacity (100 by default).
    statement_cache_size: usize,
    /// `max_connection_lifetime`, in seconds; unset by default, `0` disables it.
    max_connection_lifetime: Option<Duration>,
    /// `max_idle_connection_lifetime`, in seconds; 300 seconds by default, `0` disables it.
    max_idle_connection_lifetime: Option<Duration>,
    /// `application_name`.
    application_name: Option<String>,
    /// `channel_binding`: `disable`, `prefer` (default) or `require`.
    channel_binding: ChannelBinding,
    /// `options`: forwarded to the connection config.
    options: Option<String>,
}
564
565impl PostgreSql {
566    /// Create a new connection to the database.
567    pub async fn new(url: PostgresUrl) -> crate::Result<Self> {
568        let config = url.to_config();
569
570        let mut tls_builder = TlsConnector::builder();
571
572        {
573            let ssl_params = url.ssl_params();
574            let auth = ssl_params.to_owned().into_auth().await?;
575
576            if let Some(certificate) = auth.certificate.0 {
577                tls_builder.add_root_certificate(certificate);
578            }
579
580            tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts);
581
582            if let Some(identity) = auth.identity.0 {
583                tls_builder.identity(identity);
584            }
585        }
586
587        let tls = MakeTlsConnector::new(tls_builder.build()?);
588        let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?;
589
590        tokio::spawn(conn.map(|r| match r {
591            Ok(_) => (),
592            Err(e) => {
593                tracing::error!("Error in PostgreSQL connection: {:?}", e);
594            }
595        }));
596
597        // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection.
598        // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed.
599        // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter.
600        // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown.
601        if let Some(schema) = &url.query_params.schema {
602            // PGBouncer does not support the search_path connection parameter.
603            // https://www.pgbouncer.org/config.html#ignore_startup_parameters
604            if url.query_params.pg_bouncer
605                || url.flavour().is_unknown()
606                || (url.flavour().is_cockroach() && !is_safe_identifier(schema))
607            {
608                let session_variables = format!(
609                    r##"{set_search_path}"##,
610                    set_search_path = SetSearchPath(url.query_params.schema.as_deref())
611                );
612
613                client.simple_query(session_variables.as_str()).await?;
614            }
615        }
616
617        Ok(Self {
618            client: PostgresClient(client),
619            socket_timeout: url.query_params.socket_timeout,
620            pg_bouncer: url.query_params.pg_bouncer,
621            statement_cache: Mutex::new(url.cache()),
622            is_healthy: AtomicBool::new(true),
623        })
624    }
625
626    /// The underlying tokio_postgres::Client. Only available with the
627    /// `expose-drivers` Cargo feature. This is a lower level API when you need
628    /// to get into database specific features.
629    #[cfg(feature = "expose-drivers")]
630    pub fn client(&self) -> &tokio_postgres::Client {
631        &self.client.0
632    }
633
634    async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<Statement> {
635        let mut cache = self.statement_cache.lock().await;
636        let capacity = cache.capacity();
637        let stored = cache.len();
638
639        match cache.get_mut(sql) {
640            Some(stmt) => {
641                tracing::trace!(message = "CACHE HIT!", query = sql, capacity = capacity, stored = stored,);
642
643                Ok(stmt.clone()) // arc'd
644            }
645            None => {
646                tracing::trace!(message = "CACHE MISS!", query = sql, capacity = capacity, stored = stored,);
647
648                let param_types = conversion::params_to_types(params);
649                let stmt = self.perform_io(self.client.0.prepare_typed(sql, &param_types)).await?;
650
651                cache.insert(sql.to_string(), stmt.clone());
652
653                Ok(stmt)
654            }
655        }
656    }
657
658    async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T>
659    where
660        F: Future<Output = Result<T, tokio_postgres::Error>>,
661    {
662        match super::timeout::socket(self.socket_timeout, fut).await {
663            Err(e) if e.is_closed() => {
664                self.is_healthy.store(false, Ordering::SeqCst);
665                Err(e)
666            }
667            res => res,
668        }
669    }
670
671    fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> {
672        if params.len() > i16::MAX as usize {
673            // tokio_postgres would return an error here. Let's avoid calling the driver
674            // and return an error early.
675            let kind = ErrorKind::QueryInvalidInput(format!(
676                "too many bind variables in prepared statement, expected maximum of {}, received {}",
677                i16::MAX,
678                params.len()
679            ));
680            Err(Error::builder(kind).build())
681        } else {
682            Ok(())
683        }
684    }
685}
686
// Display-impl for the `search_path` connection parameter on CockroachDB:
// the schema name is written verbatim, without any quoting.
struct CockroachSearchPath<'a>(&'a str);

impl Display for CockroachSearchPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CockroachSearchPath(schema) = self;
        write!(f, "{schema}")
    }
}
695
// Display-impl for the `search_path` connection parameter on PostgreSQL: the
// schema is written as a double-quoted identifier. Embedded `"` characters are
// doubled (the SQL escape inside quoted identifiers), so a schema name cannot
// terminate the quoted identifier early.
struct PostgresSearchPath<'a>(&'a str);

impl Display for PostgresSearchPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("\"")?;
        f.write_str(&self.0.replace('"', "\"\""))?;
        f.write_str("\"")?;

        Ok(())
    }
}
708
// A `SET search_path` statement (Display-impl) for connection initialization.
// Renders nothing when no schema is given. The schema is emitted as a
// double-quoted identifier with embedded `"` doubled, so a schema name taken
// from the connection URL cannot close the identifier and append further SQL
// (the statement is sent through `simple_query`, which accepts multiple statements).
struct SetSearchPath<'a>(Option<&'a str>);

impl Display for SetSearchPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(schema) = self.0 {
            f.write_str("SET search_path = \"")?;
            f.write_str(&schema.replace('"', "\"\""))?;
            f.write_str("\";\n")?;
        }

        Ok(())
    }
}
723
// Empty impl: PostgreSql relies entirely on the trait's default method implementations.
impl TransactionCapable for PostgreSql {}
725
726#[async_trait]
727impl Queryable for PostgreSql {
728    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
729        let (sql, params) = visitor::Postgres::build(q)?;
730
731        self.query_raw(sql.as_str(), &params[..]).await
732    }
733
734    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
735        self.check_bind_variables_len(params)?;
736
737        metrics::query("postgres.query_raw", sql, params, move || async move {
738            let stmt = self.fetch_cached(sql, &[]).await?;
739
740            if stmt.params().len() != params.len() {
741                let kind =
742                    ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
743
744                return Err(Error::builder(kind).build());
745            }
746
747            let rows = self.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())).await?;
748
749            let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
750
751            for row in rows {
752                result.rows.push(row.get_result_row()?);
753            }
754
755            Ok(result)
756        })
757        .await
758    }
759
760    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
761        self.check_bind_variables_len(params)?;
762
763        metrics::query("postgres.query_raw", sql, params, move || async move {
764            let stmt = self.fetch_cached(sql, params).await?;
765
766            if stmt.params().len() != params.len() {
767                let kind =
768                    ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
769
770                return Err(Error::builder(kind).build());
771            }
772
773            let rows = self.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())).await?;
774
775            let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
776
777            for row in rows {
778                result.rows.push(row.get_result_row()?);
779            }
780
781            Ok(result)
782        })
783        .await
784    }
785
786    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
787        let (sql, params) = visitor::Postgres::build(q)?;
788
789        self.execute_raw(sql.as_str(), &params[..]).await
790    }
791
792    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
793        self.check_bind_variables_len(params)?;
794
795        metrics::query("postgres.execute_raw", sql, params, move || async move {
796            let stmt = self.fetch_cached(sql, &[]).await?;
797
798            if stmt.params().len() != params.len() {
799                let kind =
800                    ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
801
802                return Err(Error::builder(kind).build());
803            }
804
805            let changes =
806                self.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())).await?;
807
808            Ok(changes)
809        })
810        .await
811    }
812
813    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
814        self.check_bind_variables_len(params)?;
815
816        metrics::query("postgres.execute_raw", sql, params, move || async move {
817            let stmt = self.fetch_cached(sql, params).await?;
818
819            if stmt.params().len() != params.len() {
820                let kind =
821                    ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
822
823                return Err(Error::builder(kind).build());
824            }
825
826            let changes =
827                self.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())).await?;
828
829            Ok(changes)
830        })
831        .await
832    }
833
834    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
835        metrics::query("postgres.raw_cmd", cmd, &[], move || async move {
836            self.perform_io(self.client.0.simple_query(cmd)).await?;
837            Ok(())
838        })
839        .await
840    }
841
842    async fn version(&self) -> crate::Result<Option<String>> {
843        let query = r#"SELECT version()"#;
844        let rows = self.query_raw(query, &[]).await?;
845
846        let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
847
848        Ok(version_string)
849    }
850
851    fn is_healthy(&self) -> bool {
852        self.is_healthy.load(Ordering::SeqCst)
853    }
854
855    async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> {
856        if self.pg_bouncer {
857            tx.raw_cmd("DEALLOCATE ALL").await
858        } else {
859            Ok(())
860        }
861    }
862
863    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
864        if matches!(isolation_level, IsolationLevel::Snapshot) {
865            return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
866        }
867
868        self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
869
870        Ok(())
871    }
872
873    fn requires_isolation_first(&self) -> bool {
874        false
875    }
876}
877
/// Sorted list of CockroachDB's reserved keywords.
/// Taken from <https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords>
///
/// `is_safe_identifier` looks identifiers up here with `binary_search`, so this list must stay
/// sorted in ascending byte order. An out-of-order entry makes the lookup result unspecified, and
/// a reserved word could then be treated as safe to leave unquoted. Entries are lowercase and
/// matched exactly (case-sensitively).
const RESERVED_KEYWORDS: [&str; 79] = [
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "column",
    "concurrently",
    "constraint",
    "create",
    "current_catalog",
    "current_date",
    "current_role",
    "current_schema",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "fetch",
    "for",
    "foreign",
    "from",
    "grant",
    "group",
    "having",
    "in",
    "initially",
    "intersect",
    "into",
    "lateral",
    "leading",
    "limit",
    "localtime",
    "localtimestamp",
    "not",
    "null",
    "offset",
    "on",
    "only",
    "or",
    "order",
    "placing",
    "primary",
    "references",
    "returning",
    "select",
    "session_user",
    "some",
    "symmetric",
    "table",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "variadic",
    "when",
    "where",
    "window",
    "with",
];
961
/// Sorted list of CockroachDB's reserved type function names.
/// Taken from <https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords>
///
/// Like `RESERVED_KEYWORDS`, this list is searched with `binary_search` by `is_safe_identifier`.
/// It must stay sorted in ascending byte order, with lowercase entries.
const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [
    "authorization",
    "collation",
    "cross",
    "full",
    "ilike",
    "inner",
    "is",
    "isnull",
    "join",
    "left",
    "like",
    "natural",
    "none",
    "notnull",
    "outer",
    "overlaps",
    "right",
    "similar",
];
984
985/// Returns true if a Postgres identifier is considered "safe".
986///
987/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted.
988///
989/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
990/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers
991fn is_safe_identifier(ident: &str) -> bool {
992    if ident.is_empty() {
993        return false;
994    }
995
996    // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords.
997    if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() {
998        return false;
999    }
1000
1001    let mut chars = ident.chars();
1002
1003    let first = chars.next().unwrap();
1004
1005    // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_).
1006    if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' {
1007        return false;
1008    }
1009
1010    for c in chars {
1011        // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($).
1012        if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' {
1013            return false;
1014        }
1015    }
1016
1017    true
1018}
1019
#[cfg(test)]
mod tests {
    // Pure unit tests cover URL parsing and `is_safe_identifier`. The `#[tokio::test]` cases are
    // integration tests that need a live PostgreSQL server (`CONN_STR`) and, for the CockroachDB
    // case, a live CockroachDB server (`CRDB_CONN_STR`).
    use super::*;
    use crate::tests::test_api::postgres::CONN_STR;
    use crate::tests::test_api::CRDB_CONN_STR;
    use crate::{connector::Queryable, error::*, single::Sqlint};
    use url::Url;

    #[test]
    fn should_parse_socket_url() {
        let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap();
        assert_eq!("dbname", url.dbname());
        assert_eq!("/var/run/psql.sock", url.host());
    }

    #[test]
    fn should_parse_escaped_url() {
        let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap();
        assert_eq!("dbname", url.dbname());
        assert_eq!("/var/run/postgresql", url.host());
    }

    #[test]
    fn should_allow_changing_of_cache_size() {
        let url =
            PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap();
        assert_eq!(420, url.cache().capacity());
    }

    #[test]
    fn should_have_default_cache_size() {
        let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap();
        assert_eq!(100, url.cache().capacity());
    }

    #[test]
    fn should_have_application_name() {
        let url =
            PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap();
        assert_eq!(Some("test"), url.application_name());
    }

    #[test]
    fn should_have_channel_binding() {
        let url =
            PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap();
        assert_eq!(ChannelBinding::Require, url.channel_binding());
    }

    #[test]
    fn should_have_default_channel_binding() {
        let url =
            PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap();
        assert_eq!(ChannelBinding::Prefer, url.channel_binding());

        let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap();
        assert_eq!(ChannelBinding::Prefer, url.channel_binding());
    }

    #[test]
    fn should_not_enable_caching_with_pgbouncer() {
        let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap();
        assert_eq!(0, url.cache().capacity());
    }

    #[test]
    fn should_parse_default_host() {
        let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap();
        assert_eq!("dbname", url.dbname());
        assert_eq!("localhost", url.host());
    }

    #[test]
    fn should_handle_options_field() {
        let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap())
            .unwrap();

        assert_eq!("--cluster=my_cluster", url.options().unwrap());
    }

    #[tokio::test]
    async fn test_custom_search_path_pg() {
        async fn test_path(schema_name: &str) -> Option<String> {
            let mut url = Url::parse(&CONN_STR).unwrap();
            url.query_pairs_mut().append_pair("schema", schema_name);

            let mut pg_url = PostgresUrl::new(url).unwrap();
            pg_url.set_flavour(PostgresFlavour::Postgres);

            let client = PostgreSql::new(pg_url).await.unwrap();

            let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
            let row = result_set.first().unwrap();

            row[0].to_string()
        }

        // The Postgres flavour sends the schema as the already-quoted `search_path` connection
        // parameter (see `search_path_pg_should_be_set_with_param`). The server echoes that value
        // verbatim, so even safe identifiers come back quoted.

        // Safe
        assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\""));
        assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\""));
        assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
        assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\""));
        assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
        assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
        assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));

        // Not safe
        assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
        assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
        assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
        assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
        assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
        assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
        assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
        assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
        assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
        assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
        assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
        assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
        assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
        assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));

        for ident in RESERVED_KEYWORDS {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }
    }

    #[tokio::test]
    async fn test_custom_search_path_pg_pgbouncer() {
        async fn test_path(schema_name: &str) -> Option<String> {
            let mut url = Url::parse(&CONN_STR).unwrap();
            url.query_pairs_mut().append_pair("schema", schema_name);
            // The key must be spelled `pgbouncer` (as in `should_not_enable_caching_with_pgbouncer`).
            // A misspelled key leaves PgBouncer mode off and silently tests the plain Postgres path.
            url.query_pairs_mut().append_pair("pgbouncer", "true");

            let mut pg_url = PostgresUrl::new(url).unwrap();
            pg_url.set_flavour(PostgresFlavour::Postgres);

            let client = PostgreSql::new(pg_url).await.unwrap();

            let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
            let row = result_set.first().unwrap();

            row[0].to_string()
        }

        // With PgBouncer the schema is not sent as a connection parameter (see
        // `search_path_pgbouncer_should_be_set_with_query`). It is applied by a query after
        // connecting, so the server reports it with its own identifier quoting. That is the same
        // output as `test_custom_search_path_unknown_pg`, which also sets the path via a query.

        // Safe
        assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
        assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
        assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
        assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
        assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
        assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
        assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));

        // Not safe
        assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
        assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
        assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
        assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
        assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
        assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
        assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
        assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
        assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
        assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
        assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
        assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
        assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
        assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));

        for ident in RESERVED_KEYWORDS {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }
    }

    #[tokio::test]
    async fn test_custom_search_path_crdb() {
        async fn test_path(schema_name: &str) -> Option<String> {
            let mut url = Url::parse(&CRDB_CONN_STR).unwrap();
            url.query_pairs_mut().append_pair("schema", schema_name);

            let mut pg_url = PostgresUrl::new(url).unwrap();
            pg_url.set_flavour(PostgresFlavour::Cockroach);

            let client = PostgreSql::new(pg_url).await.unwrap();

            let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
            let row = result_set.first().unwrap();

            row[0].to_string()
        }

        // Safe
        assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
        assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
        assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra"));
        assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
        assert_eq!(test_path("héllo").await.as_deref(), Some("héllo"));
        assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$"));
        assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$"));

        // Not safe
        assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
        assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
        assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
        assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
        assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
        assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
        assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
        assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
        assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
        assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
        assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
        assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
        assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
        assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));

        for ident in RESERVED_KEYWORDS {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }
    }

    #[tokio::test]
    async fn test_custom_search_path_unknown_pg() {
        async fn test_path(schema_name: &str) -> Option<String> {
            let mut url = Url::parse(&CONN_STR).unwrap();
            url.query_pairs_mut().append_pair("schema", schema_name);

            let mut pg_url = PostgresUrl::new(url).unwrap();
            pg_url.set_flavour(PostgresFlavour::Unknown);

            let client = PostgreSql::new(pg_url).await.unwrap();

            let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
            let row = result_set.first().unwrap();

            row[0].to_string()
        }

        // Safe
        assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
        assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
        assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
        assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
        assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
        assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
        assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));

        // Not safe
        assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
        assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
        assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
        assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
        assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
        assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
        assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
        assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
        assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
        assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
        assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
        assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
        assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
        assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));

        for ident in RESERVED_KEYWORDS {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }
    }

    #[tokio::test]
    async fn test_custom_search_path_unknown_crdb() {
        // NOTE(review): despite its name, this test connects with `CONN_STR` (PostgreSQL), which
        // makes it identical to `test_custom_search_path_unknown_pg`. Confirm whether
        // `CRDB_CONN_STR` was intended. The expected values below would then need revisiting,
        // since CockroachDB formats `SHOW search_path` differently (see `test_custom_search_path_crdb`).
        async fn test_path(schema_name: &str) -> Option<String> {
            let mut url = Url::parse(&CONN_STR).unwrap();
            url.query_pairs_mut().append_pair("schema", schema_name);

            let mut pg_url = PostgresUrl::new(url).unwrap();
            pg_url.set_flavour(PostgresFlavour::Unknown);

            let client = PostgreSql::new(pg_url).await.unwrap();

            let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
            let row = result_set.first().unwrap();

            row[0].to_string()
        }

        // Safe
        assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
        assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
        assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
        assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
        assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
        assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
        assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));

        // Not safe
        assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
        assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
        assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
        assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
        assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
        assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
        assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
        assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
        assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
        assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
        assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
        assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
        assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
        assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));

        for ident in RESERVED_KEYWORDS {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
        }
    }

    #[tokio::test]
    async fn should_map_nonexisting_database_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_path("/this_does_not_exist");

        let res = Sqlint::new(url.as_str()).await;

        assert!(res.is_err());

        match res {
            Ok(_) => unreachable!(),
            Err(e) => match e.kind() {
                ErrorKind::DatabaseDoesNotExist { db_name } => {
                    assert_eq!(Some("3D000"), e.original_code());
                    assert_eq!(Some("database \"this_does_not_exist\" does not exist"), e.original_message());
                    assert_eq!(&Name::available("this_does_not_exist"), db_name)
                }
                kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind),
            },
        }
    }

    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("WRONG").unwrap();

        let res = Sqlint::new(url.as_str()).await;
        assert!(res.is_err());

        let err = res.unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG")));
    }

    #[tokio::test]
    async fn should_map_tls_errors() {
        let mut url = Url::parse(&CONN_STR).expect("parsing url");
        url.set_query(Some("sslmode=require&sslaccept=strict"));

        let res = Sqlint::new(url.as_str()).await;

        assert!(res.is_err());

        match res {
            Ok(_) => unreachable!(),
            Err(e) => match e.kind() {
                ErrorKind::TlsError { .. } => (),
                other => panic!("{:#?}", other),
            },
        }
    }

    #[tokio::test]
    async fn should_map_incorrect_parameters_error() {
        let url = Url::parse(&CONN_STR).unwrap();
        let conn = Sqlint::new(url.as_str()).await.unwrap();

        let res = conn.query_raw("SELECT $1", &[Value::integer(1), Value::integer(2)]).await;

        assert!(res.is_err());

        match res {
            Ok(_) => unreachable!(),
            Err(e) => match e.kind() {
                ErrorKind::IncorrectNumberOfParameters { expected, actual } => {
                    assert_eq!(1, *expected);
                    assert_eq!(2, *actual);
                }
                other => panic!("{:#?}", other),
            },
        }
    }

    #[test]
    fn test_safe_ident() {
        // Safe
        assert!(is_safe_identifier("hello"));
        assert!(is_safe_identifier("_hello"));
        assert!(is_safe_identifier("àbracadabra"));
        assert!(is_safe_identifier("h3ll0"));
        assert!(is_safe_identifier("héllo"));
        assert!(is_safe_identifier("héll0$"));
        assert!(is_safe_identifier("héll_0$"));
        assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m"));

        // Not safe
        assert!(!is_safe_identifier(""));
        assert!(!is_safe_identifier("Hello"));
        assert!(!is_safe_identifier("hEllo"));
        assert!(!is_safe_identifier("$hello"));
        assert!(!is_safe_identifier("hello!"));
        assert!(!is_safe_identifier("hello#"));
        assert!(!is_safe_identifier("he llo"));
        assert!(!is_safe_identifier(" hello"));
        assert!(!is_safe_identifier("he-llo"));
        assert!(!is_safe_identifier("hÉllo"));
        assert!(!is_safe_identifier("1337"));
        assert!(!is_safe_identifier("_HELLO"));
        assert!(!is_safe_identifier("HELLO"));
        assert!(!is_safe_identifier("HELLO$"));
        assert!(!is_safe_identifier("ÀBRACADABRA"));

        for ident in RESERVED_KEYWORDS {
            assert!(!is_safe_identifier(ident));
        }

        for ident in RESERVED_TYPE_FUNCTION_NAMES {
            assert!(!is_safe_identifier(ident));
        }
    }

    #[test]
    fn search_path_pgbouncer_should_be_set_with_query() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.query_pairs_mut().append_pair("schema", "hello");
        url.query_pairs_mut().append_pair("pgbouncer", "true");

        let mut pg_url = PostgresUrl::new(url).unwrap();
        pg_url.set_flavour(PostgresFlavour::Postgres);

        let config = pg_url.to_config();

        // PGBouncer does not support the `search_path` connection parameter.
        // When `pgbouncer=true`, config.search_path should be None,
        // And the `search_path` should be set via a db query after connection.
        assert_eq!(config.get_search_path(), None);
    }

    #[test]
    fn search_path_pg_should_be_set_with_param() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.query_pairs_mut().append_pair("schema", "hello");

        let mut pg_url = PostgresUrl::new(url).unwrap();
        pg_url.set_flavour(PostgresFlavour::Postgres);

        let config = pg_url.to_config();

        // Postgres supports setting the search_path via a connection parameter.
        assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned()));
    }

    #[test]
    fn search_path_crdb_safe_ident_should_be_set_with_param() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.query_pairs_mut().append_pair("schema", "hello");

        let mut pg_url = PostgresUrl::new(url).unwrap();
        pg_url.set_flavour(PostgresFlavour::Cockroach);

        let config = pg_url.to_config();

        // CRDB supports setting the search_path via a connection parameter if the identifier is safe.
        assert_eq!(config.get_search_path(), Some(&"hello".to_owned()));
    }

    #[test]
    fn search_path_crdb_unsafe_ident_should_be_set_with_query() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.query_pairs_mut().append_pair("schema", "HeLLo");

        let mut pg_url = PostgresUrl::new(url).unwrap();
        pg_url.set_flavour(PostgresFlavour::Cockroach);

        let config = pg_url.to_config();

        // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe.
        assert_eq!(config.get_search_path(), None);
    }
}