1mod conversion;
2mod error;
3
4use crate::{
5 ast::{Query, Value},
6 connector::{metrics, queryable::*, ResultSet, Transaction},
7 error::{Error, ErrorKind},
8 visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use futures::{future::FutureExt, lock::Mutex};
12use lru_cache::LruCache;
13use native_tls::{Certificate, Identity, TlsConnector};
14use percent_encoding::percent_decode;
15use postgres_native_tls::MakeTlsConnector;
16use std::{
17 borrow::{Borrow, Cow},
18 fmt::{Debug, Display},
19 fs,
20 future::Future,
21 sync::atomic::{AtomicBool, Ordering},
22 time::Duration,
23};
24use tokio_postgres::{
25 config::{ChannelBinding, SslMode},
26 Client, Config, Statement,
27};
28use url::Url;
29
/// Schema name reported by `PostgresUrl::schema` when the connection string has no
/// `schema` parameter.
pub(crate) const DEFAULT_SCHEMA: &str = "public";
31
32#[cfg(feature = "expose-drivers")]
35pub use tokio_postgres;
36
37use super::IsolationLevel;
38
/// Wrapper that keeps its payload out of `Debug` output.
///
/// Whatever `T` is, the wrapper formats as the fixed text `<HIDDEN>`.
#[derive(Clone)]
struct Hidden<T>(T);

impl<T> Debug for Hidden<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped value is deliberately never inspected.
        write!(formatter, "<HIDDEN>")
    }
}
47
48struct PostgresClient(Client);
49
50impl Debug for PostgresClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.write_str("PostgresClient")
53 }
54}
55
56#[derive(Debug)]
58#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
59pub struct PostgreSql {
60 client: PostgresClient,
61 pg_bouncer: bool,
62 socket_timeout: Option<Duration>,
63 statement_cache: Mutex<LruCache<String, Statement>>,
64 is_healthy: AtomicBool,
65}
66
/// Certificate validation policy, selected with the `sslaccept` connection parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
pub enum SslAcceptMode {
    /// Selected by `sslaccept=strict`, and also by any unrecognised `sslaccept` value.
    Strict,
    /// Selected by `sslaccept=accept_invalid_certs`; this is also the parser's default.
    /// When active, the TLS builder is told to accept invalid certificates.
    AcceptInvalidCerts,
}
73
74#[derive(Debug, Clone)]
75#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
76pub struct SslParams {
77 certificate_file: Option<String>,
78 identity_file: Option<String>,
79 identity_password: Hidden<Option<String>>,
80 ssl_accept_mode: SslAcceptMode,
81}
82
83#[derive(Debug)]
84struct SslAuth {
85 certificate: Hidden<Option<Certificate>>,
86 identity: Hidden<Option<Identity>>,
87 ssl_accept_mode: SslAcceptMode,
88}
89
90impl Default for SslAuth {
91 fn default() -> Self {
92 Self { certificate: Hidden(None), identity: Hidden(None), ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts }
93 }
94}
95
96impl SslAuth {
97 fn certificate(&mut self, certificate: Certificate) -> &mut Self {
98 self.certificate = Hidden(Some(certificate));
99 self
100 }
101
102 fn identity(&mut self, identity: Identity) -> &mut Self {
103 self.identity = Hidden(Some(identity));
104 self
105 }
106
107 fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self {
108 self.ssl_accept_mode = mode;
109 self
110 }
111}
112
113impl SslParams {
114 async fn into_auth(self) -> crate::Result<SslAuth> {
115 let mut auth = SslAuth::default();
116 auth.accept_mode(self.ssl_accept_mode);
117
118 if let Some(ref cert_file) = self.certificate_file {
119 let cert = fs::read(cert_file).map_err(|err| {
120 Error::builder(ErrorKind::TlsError { message: format!("cert file not found ({err})") }).build()
121 })?;
122
123 auth.certificate(Certificate::from_pem(&cert)?);
124 }
125
126 if let Some(ref identity_file) = self.identity_file {
127 let db = fs::read(identity_file).map_err(|err| {
128 Error::builder(ErrorKind::TlsError { message: format!("identity file not found ({err})") }).build()
129 })?;
130 let password = self.identity_password.0.as_deref().unwrap_or("");
131 let identity = Identity::from_pkcs12(&db, password)?;
132
133 auth.identity(identity);
134 }
135
136 Ok(auth)
137 }
138}
139
/// The kind of server a [`PostgresUrl`] points at.
///
/// The flavour decides how the schema search path is applied to a connection.
#[derive(Debug, Clone, Copy)]
pub enum PostgresFlavour {
    /// A PostgreSQL server.
    Postgres,
    /// A CockroachDB server.
    Cockroach,
    /// Not determined; this is what `PostgresUrl::new` starts with.
    Unknown,
}
146
147impl PostgresFlavour {
148 fn is_postgres(&self) -> bool {
152 matches!(self, Self::Postgres)
153 }
154
155 fn is_cockroach(&self) -> bool {
159 matches!(self, Self::Cockroach)
160 }
161
162 fn is_unknown(&self) -> bool {
166 matches!(self, Self::Unknown)
167 }
168}
169
170#[derive(Debug, Clone)]
173#[cfg_attr(feature = "docs", doc(cfg(feature = "postgresql")))]
174pub struct PostgresUrl {
175 url: Url,
176 query_params: PostgresUrlQueryParams,
177 flavour: PostgresFlavour,
178}
179
180impl PostgresUrl {
181 pub fn new(url: Url) -> Result<Self, Error> {
184 let query_params = Self::parse_query_params(&url)?;
185
186 Ok(Self { url, query_params, flavour: PostgresFlavour::Unknown })
187 }
188
189 pub fn url(&self) -> &Url {
191 &self.url
192 }
193
194 pub fn username(&self) -> Cow<str> {
196 match percent_decode(self.url.username().as_bytes()).decode_utf8() {
197 Ok(username) => username,
198 Err(_) => {
199 tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
200
201 self.url.username().into()
202 }
203 }
204 }
205
206 pub fn host(&self) -> &str {
212 match (self.query_params.host.as_ref(), self.url.host_str()) {
213 (Some(host), _) => host.as_str(),
214 (None, Some("")) => "localhost",
215 (None, None) => "localhost",
216 (None, Some(host)) => host,
217 }
218 }
219
220 pub fn dbname(&self) -> &str {
222 match self.url.path_segments() {
223 Some(mut segments) => segments.next().unwrap_or("postgres"),
224 None => "postgres",
225 }
226 }
227
228 pub fn password(&self) -> Cow<str> {
230 match self.url.password().and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) {
231 Some(password) => password,
232 None => self.url.password().unwrap_or("").into(),
233 }
234 }
235
236 pub fn port(&self) -> u16 {
238 self.url.port().unwrap_or(5432)
239 }
240
241 pub fn schema(&self) -> &str {
243 self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA)
244 }
245
246 pub fn pg_bouncer(&self) -> bool {
248 self.query_params.pg_bouncer
249 }
250
251 pub fn connect_timeout(&self) -> Option<Duration> {
253 self.query_params.connect_timeout
254 }
255
256 pub fn pool_timeout(&self) -> Option<Duration> {
258 self.query_params.pool_timeout
259 }
260
261 pub fn socket_timeout(&self) -> Option<Duration> {
263 self.query_params.socket_timeout
264 }
265
266 pub fn max_connection_lifetime(&self) -> Option<Duration> {
268 self.query_params.max_connection_lifetime
269 }
270
271 pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
273 self.query_params.max_idle_connection_lifetime
274 }
275
276 pub fn application_name(&self) -> Option<&str> {
278 self.query_params.application_name.as_deref()
279 }
280
281 pub fn channel_binding(&self) -> ChannelBinding {
282 self.query_params.channel_binding
283 }
284
285 pub(crate) fn cache(&self) -> LruCache<String, Statement> {
286 if self.query_params.pg_bouncer {
287 LruCache::new(0)
288 } else {
289 LruCache::new(self.query_params.statement_cache_size)
290 }
291 }
292
293 pub(crate) fn options(&self) -> Option<&str> {
294 self.query_params.options.as_deref()
295 }
296
297 pub fn set_flavour(&mut self, flavour: PostgresFlavour) {
305 self.flavour = flavour;
306 }
307
308 fn parse_query_params(url: &Url) -> Result<PostgresUrlQueryParams, Error> {
309 let mut connection_limit = None;
310 let mut schema = None;
311 let mut certificate_file = None;
312 let mut identity_file = None;
313 let mut identity_password = None;
314 let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts;
315 let mut ssl_mode = SslMode::Prefer;
316 let mut host = None;
317 let mut application_name = None;
318 let mut channel_binding = ChannelBinding::Prefer;
319 let mut socket_timeout = None;
320 let mut connect_timeout = Some(Duration::from_secs(5));
321 let mut pool_timeout = Some(Duration::from_secs(10));
322 let mut pg_bouncer = false;
323 let mut statement_cache_size = 100;
324 let mut max_connection_lifetime = None;
325 let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
326 let mut options = None;
327
328 for (k, v) in url.query_pairs() {
329 match k.as_ref() {
330 "pgbouncer" => {
331 pg_bouncer =
332 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
333 }
334 "sslmode" => {
335 match v.as_ref() {
336 "disable" => ssl_mode = SslMode::Disable,
337 "prefer" => ssl_mode = SslMode::Prefer,
338 "require" => ssl_mode = SslMode::Require,
339 _ => {
340 tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v);
341 }
342 };
343 }
344 "sslcert" => {
345 certificate_file = Some(v.to_string());
346 }
347 "sslidentity" => {
348 identity_file = Some(v.to_string());
349 }
350 "sslpassword" => {
351 identity_password = Some(v.to_string());
352 }
353 "statement_cache_size" => {
354 statement_cache_size =
355 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
356 }
357 "sslaccept" => {
358 match v.as_ref() {
359 "strict" => {
360 ssl_accept_mode = SslAcceptMode::Strict;
361 }
362 "accept_invalid_certs" => {
363 ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts;
364 }
365 _ => {
366 tracing::debug!(
367 message = "Unsupported SSL accept mode, defaulting to `strict`",
368 mode = &*v
369 );
370
371 ssl_accept_mode = SslAcceptMode::Strict;
372 }
373 };
374 }
375 "schema" => {
376 schema = Some(v.to_string());
377 }
378 "connection_limit" => {
379 let as_int: usize =
380 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
381 connection_limit = Some(as_int);
382 }
383 "host" => {
384 host = Some(v.to_string());
385 }
386 "socket_timeout" => {
387 let as_int =
388 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
389 socket_timeout = Some(Duration::from_secs(as_int));
390 }
391 "connect_timeout" => {
392 let as_int =
393 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
394
395 if as_int == 0 {
396 connect_timeout = None;
397 } else {
398 connect_timeout = Some(Duration::from_secs(as_int));
399 }
400 }
401 "pool_timeout" => {
402 let as_int =
403 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
404
405 if as_int == 0 {
406 pool_timeout = None;
407 } else {
408 pool_timeout = Some(Duration::from_secs(as_int));
409 }
410 }
411 "max_connection_lifetime" => {
412 let as_int =
413 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
414
415 if as_int == 0 {
416 max_connection_lifetime = None;
417 } else {
418 max_connection_lifetime = Some(Duration::from_secs(as_int));
419 }
420 }
421 "max_idle_connection_lifetime" => {
422 let as_int =
423 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
424
425 if as_int == 0 {
426 max_idle_connection_lifetime = None;
427 } else {
428 max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
429 }
430 }
431 "application_name" => {
432 application_name = Some(v.to_string());
433 }
434 "channel_binding" => {
435 match v.as_ref() {
436 "disable" => channel_binding = ChannelBinding::Disable,
437 "prefer" => channel_binding = ChannelBinding::Prefer,
438 "require" => channel_binding = ChannelBinding::Require,
439 _ => {
440 tracing::debug!(
441 message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`",
442 channel_binding = &*v
443 );
444 }
445 };
446 }
447 "options" => {
448 options = Some(v.to_string());
449 }
450 _ => {
451 tracing::trace!(message = "Discarding connection string param", param = &*k);
452 }
453 };
454 }
455
456 Ok(PostgresUrlQueryParams {
457 ssl_params: SslParams {
458 certificate_file,
459 identity_file,
460 ssl_accept_mode,
461 identity_password: Hidden(identity_password),
462 },
463 connection_limit,
464 schema,
465 ssl_mode,
466 host,
467 connect_timeout,
468 pool_timeout,
469 socket_timeout,
470 pg_bouncer,
471 statement_cache_size,
472 max_connection_lifetime,
473 max_idle_connection_lifetime,
474 application_name,
475 channel_binding,
476 options,
477 })
478 }
479
480 pub(crate) fn ssl_params(&self) -> &SslParams {
481 &self.query_params.ssl_params
482 }
483
484 #[cfg(feature = "pooled")]
485 pub(crate) fn connection_limit(&self) -> Option<usize> {
486 self.query_params.connection_limit
487 }
488
489 fn set_search_path(&self, config: &mut Config) {
493 if self.query_params.pg_bouncer {
496 return;
497 }
498
499 if let Some(schema) = &self.query_params.schema {
500 if self.flavour().is_cockroach() && is_safe_identifier(schema) {
501 config.search_path(CockroachSearchPath(schema).to_string());
502 }
503
504 if self.flavour().is_postgres() {
505 config.search_path(PostgresSearchPath(schema).to_string());
506 }
507 }
508 }
509
510 pub(crate) fn to_config(&self) -> Config {
511 let mut config = Config::new();
512
513 config.user(self.username().borrow());
514 config.password(self.password().borrow() as &str);
515 config.host(self.host());
516 config.port(self.port());
517 config.dbname(self.dbname());
518 config.pgbouncer_mode(self.query_params.pg_bouncer);
519
520 if let Some(options) = self.options() {
521 config.options(options);
522 }
523
524 if let Some(application_name) = self.application_name() {
525 config.application_name(application_name);
526 }
527
528 if let Some(connect_timeout) = self.query_params.connect_timeout {
529 config.connect_timeout(connect_timeout);
530 }
531
532 self.set_search_path(&mut config);
533
534 config.ssl_mode(self.query_params.ssl_mode);
535
536 config.channel_binding(self.query_params.channel_binding);
537
538 config
539 }
540
541 pub fn flavour(&self) -> PostgresFlavour {
542 self.flavour
543 }
544}
545
546#[derive(Debug, Clone)]
547pub(crate) struct PostgresUrlQueryParams {
548 ssl_params: SslParams,
549 connection_limit: Option<usize>,
550 schema: Option<String>,
551 ssl_mode: SslMode,
552 pg_bouncer: bool,
553 host: Option<String>,
554 socket_timeout: Option<Duration>,
555 connect_timeout: Option<Duration>,
556 pool_timeout: Option<Duration>,
557 statement_cache_size: usize,
558 max_connection_lifetime: Option<Duration>,
559 max_idle_connection_lifetime: Option<Duration>,
560 application_name: Option<String>,
561 channel_binding: ChannelBinding,
562 options: Option<String>,
563}
564
565impl PostgreSql {
566 pub async fn new(url: PostgresUrl) -> crate::Result<Self> {
568 let config = url.to_config();
569
570 let mut tls_builder = TlsConnector::builder();
571
572 {
573 let ssl_params = url.ssl_params();
574 let auth = ssl_params.to_owned().into_auth().await?;
575
576 if let Some(certificate) = auth.certificate.0 {
577 tls_builder.add_root_certificate(certificate);
578 }
579
580 tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts);
581
582 if let Some(identity) = auth.identity.0 {
583 tls_builder.identity(identity);
584 }
585 }
586
587 let tls = MakeTlsConnector::new(tls_builder.build()?);
588 let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?;
589
590 tokio::spawn(conn.map(|r| match r {
591 Ok(_) => (),
592 Err(e) => {
593 tracing::error!("Error in PostgreSQL connection: {:?}", e);
594 }
595 }));
596
597 if let Some(schema) = &url.query_params.schema {
602 if url.query_params.pg_bouncer
605 || url.flavour().is_unknown()
606 || (url.flavour().is_cockroach() && !is_safe_identifier(schema))
607 {
608 let session_variables = format!(
609 r##"{set_search_path}"##,
610 set_search_path = SetSearchPath(url.query_params.schema.as_deref())
611 );
612
613 client.simple_query(session_variables.as_str()).await?;
614 }
615 }
616
617 Ok(Self {
618 client: PostgresClient(client),
619 socket_timeout: url.query_params.socket_timeout,
620 pg_bouncer: url.query_params.pg_bouncer,
621 statement_cache: Mutex::new(url.cache()),
622 is_healthy: AtomicBool::new(true),
623 })
624 }
625
626 #[cfg(feature = "expose-drivers")]
630 pub fn client(&self) -> &tokio_postgres::Client {
631 &self.client.0
632 }
633
634 async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<Statement> {
635 let mut cache = self.statement_cache.lock().await;
636 let capacity = cache.capacity();
637 let stored = cache.len();
638
639 match cache.get_mut(sql) {
640 Some(stmt) => {
641 tracing::trace!(message = "CACHE HIT!", query = sql, capacity = capacity, stored = stored,);
642
643 Ok(stmt.clone()) }
645 None => {
646 tracing::trace!(message = "CACHE MISS!", query = sql, capacity = capacity, stored = stored,);
647
648 let param_types = conversion::params_to_types(params);
649 let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?;
650
651 cache.insert(sql.to_string(), stmt.clone());
652
653 Ok(stmt)
654 }
655 }
656 }
657
658 async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T>
659 where
660 F: Future<Output = Result<T, tokio_postgres::Error>>,
661 {
662 match super::timeout::socket(self.socket_timeout, fut).await {
663 Err(e) if e.is_closed() => {
664 self.is_healthy.store(false, Ordering::SeqCst);
665 Err(e)
666 }
667 res => res,
668 }
669 }
670
671 fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> {
672 if params.len() > i16::MAX as usize {
673 let kind = ErrorKind::QueryInvalidInput(format!(
676 "too many bind variables in prepared statement, expected maximum of {}, received {}",
677 i16::MAX,
678 params.len()
679 ));
680 Err(Error::builder(kind).build())
681 } else {
682 Ok(())
683 }
684 }
685}
686
/// Formats a schema name verbatim, without quoting, for use as CockroachDB's search path.
///
/// `PostgresUrl::set_search_path` only uses this for names accepted by
/// `is_safe_identifier`.
struct CockroachSearchPath<'a>(&'a str);

impl Display for CockroachSearchPath<'_> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(schema) = self;

        write!(formatter, "{}", schema)
    }
}
695
/// Formats a schema name as a double-quoted identifier for PostgreSQL's search path.
///
/// Double quotes inside the name are doubled (`"` becomes `""`), so the result is always
/// a single well-formed quoted identifier.
struct PostgresSearchPath<'a>(&'a str);

impl Display for PostgresSearchPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("\"")?;

        // Write the pieces between embedded quotes, joining them with an escaped quote.
        let mut pieces = self.0.split('"');

        if let Some(first) = pieces.next() {
            f.write_str(first)?;
        }

        for piece in pieces {
            f.write_str("\"\"")?;
            f.write_str(piece)?;
        }

        f.write_str("\"")
    }
}
708
/// Formats a `SET search_path = "<schema>";` statement followed by a newline, or nothing
/// at all when no schema is given.
///
/// Double quotes inside the schema name are doubled (`"` becomes `""`), so the name cannot
/// terminate the quoted identifier and append further SQL to the statement.
struct SetSearchPath<'a>(Option<&'a str>);

impl Display for SetSearchPath<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let schema = match self.0 {
            Some(schema) => schema,
            None => return Ok(()),
        };

        f.write_str("SET search_path = \"")?;

        // Write the pieces between embedded quotes, joining them with an escaped quote.
        let mut pieces = schema.split('"');

        if let Some(first) = pieces.next() {
            f.write_str(first)?;
        }

        for piece in pieces {
            f.write_str("\"\"")?;
            f.write_str(piece)?;
        }

        f.write_str("\";\n")
    }
}
723
724impl TransactionCapable for PostgreSql {}
725
726#[async_trait]
727impl Queryable for PostgreSql {
728 async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
729 let (sql, params) = visitor::Postgres::build(q)?;
730
731 self.query_raw(sql.as_str(), ¶ms[..]).await
732 }
733
734 async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
735 self.check_bind_variables_len(params)?;
736
737 metrics::query("postgres.query_raw", sql, params, move || async move {
738 let stmt = self.fetch_cached(sql, &[]).await?;
739
740 if stmt.params().len() != params.len() {
741 let kind =
742 ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
743
744 return Err(Error::builder(kind).build());
745 }
746
747 let rows = self.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())).await?;
748
749 let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
750
751 for row in rows {
752 result.rows.push(row.get_result_row()?);
753 }
754
755 Ok(result)
756 })
757 .await
758 }
759
760 async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
761 self.check_bind_variables_len(params)?;
762
763 metrics::query("postgres.query_raw", sql, params, move || async move {
764 let stmt = self.fetch_cached(sql, params).await?;
765
766 if stmt.params().len() != params.len() {
767 let kind =
768 ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
769
770 return Err(Error::builder(kind).build());
771 }
772
773 let rows = self.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())).await?;
774
775 let mut result = ResultSet::new(stmt.to_column_names(), Vec::new());
776
777 for row in rows {
778 result.rows.push(row.get_result_row()?);
779 }
780
781 Ok(result)
782 })
783 .await
784 }
785
786 async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
787 let (sql, params) = visitor::Postgres::build(q)?;
788
789 self.execute_raw(sql.as_str(), ¶ms[..]).await
790 }
791
792 async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
793 self.check_bind_variables_len(params)?;
794
795 metrics::query("postgres.execute_raw", sql, params, move || async move {
796 let stmt = self.fetch_cached(sql, &[]).await?;
797
798 if stmt.params().len() != params.len() {
799 let kind =
800 ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
801
802 return Err(Error::builder(kind).build());
803 }
804
805 let changes =
806 self.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())).await?;
807
808 Ok(changes)
809 })
810 .await
811 }
812
813 async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
814 self.check_bind_variables_len(params)?;
815
816 metrics::query("postgres.execute_raw", sql, params, move || async move {
817 let stmt = self.fetch_cached(sql, params).await?;
818
819 if stmt.params().len() != params.len() {
820 let kind =
821 ErrorKind::IncorrectNumberOfParameters { expected: stmt.params().len(), actual: params.len() };
822
823 return Err(Error::builder(kind).build());
824 }
825
826 let changes =
827 self.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())).await?;
828
829 Ok(changes)
830 })
831 .await
832 }
833
834 async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
835 metrics::query("postgres.raw_cmd", cmd, &[], move || async move {
836 self.perform_io(self.client.0.simple_query(cmd)).await?;
837 Ok(())
838 })
839 .await
840 }
841
842 async fn version(&self) -> crate::Result<Option<String>> {
843 let query = r#"SELECT version()"#;
844 let rows = self.query_raw(query, &[]).await?;
845
846 let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
847
848 Ok(version_string)
849 }
850
851 fn is_healthy(&self) -> bool {
852 self.is_healthy.load(Ordering::SeqCst)
853 }
854
855 async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> {
856 if self.pg_bouncer {
857 tx.raw_cmd("DEALLOCATE ALL").await
858 } else {
859 Ok(())
860 }
861 }
862
863 async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
864 if matches!(isolation_level, IsolationLevel::Snapshot) {
865 return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
866 }
867
868 self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
869
870 Ok(())
871 }
872
873 fn requires_isolation_first(&self) -> bool {
874 false
875 }
876}
877
/// Reserved keywords that `is_safe_identifier` refuses as identifiers.
///
/// The list must stay sorted in ascending byte order: it is looked up with
/// `binary_search`.
const RESERVED_KEYWORDS: [&str; 79] = [
    "all", "analyse", "analyze", "and", "any", "array", "as", "asc", "asymmetric",
    "both",
    "case", "cast", "check", "collate", "column", "concurrently", "constraint", "create",
    "current_catalog", "current_date", "current_role", "current_schema", "current_time",
    "current_timestamp", "current_user",
    "default", "deferrable", "desc", "distinct", "do",
    "else", "end", "except",
    "false", "fetch", "for", "foreign", "from",
    "grant", "group",
    "having",
    "in", "initially", "intersect", "into",
    "lateral", "leading", "limit", "localtime", "localtimestamp",
    "not", "null",
    "offset", "on", "only", "or", "order",
    "placing", "primary",
    "references", "returning",
    "select", "session_user", "some", "symmetric",
    "table", "then", "to", "trailing", "true",
    "union", "unique", "user", "using",
    "variadic",
    "when", "where", "window", "with",
];
961
/// Reserved type and function names that `is_safe_identifier` refuses as identifiers.
///
/// The list must stay sorted in ascending byte order: it is looked up with
/// `binary_search`.
const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [
    "authorization",
    "collation", "cross",
    "full",
    "ilike", "inner", "is", "isnull",
    "join",
    "left", "like",
    "natural", "none", "notnull",
    "outer", "overlaps",
    "right",
    "similar",
];
984
985fn is_safe_identifier(ident: &str) -> bool {
992 if ident.is_empty() {
993 return false;
994 }
995
996 if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() {
998 return false;
999 }
1000
1001 let mut chars = ident.chars();
1002
1003 let first = chars.next().unwrap();
1004
1005 if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' {
1007 return false;
1008 }
1009
1010 for c in chars {
1011 if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' {
1013 return false;
1014 }
1015 }
1016
1017 true
1018}
1019
1020#[cfg(test)]
1021mod tests {
1022 use super::*;
1023 use crate::tests::test_api::postgres::CONN_STR;
1024 use crate::tests::test_api::CRDB_CONN_STR;
1025 use crate::{connector::Queryable, error::*, single::Sqlint};
1026 use url::Url;
1027
1028 #[test]
1029 fn should_parse_socket_url() {
1030 let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap();
1031 assert_eq!("dbname", url.dbname());
1032 assert_eq!("/var/run/psql.sock", url.host());
1033 }
1034
1035 #[test]
1036 fn should_parse_escaped_url() {
1037 let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap();
1038 assert_eq!("dbname", url.dbname());
1039 assert_eq!("/var/run/postgresql", url.host());
1040 }
1041
1042 #[test]
1043 fn should_allow_changing_of_cache_size() {
1044 let url =
1045 PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap();
1046 assert_eq!(420, url.cache().capacity());
1047 }
1048
1049 #[test]
1050 fn should_have_default_cache_size() {
1051 let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap();
1052 assert_eq!(100, url.cache().capacity());
1053 }
1054
1055 #[test]
1056 fn should_have_application_name() {
1057 let url =
1058 PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap();
1059 assert_eq!(Some("test"), url.application_name());
1060 }
1061
1062 #[test]
1063 fn should_have_channel_binding() {
1064 let url =
1065 PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap();
1066 assert_eq!(ChannelBinding::Require, url.channel_binding());
1067 }
1068
1069 #[test]
1070 fn should_have_default_channel_binding() {
1071 let url =
1072 PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap();
1073 assert_eq!(ChannelBinding::Prefer, url.channel_binding());
1074
1075 let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap();
1076 assert_eq!(ChannelBinding::Prefer, url.channel_binding());
1077 }
1078
1079 #[test]
1080 fn should_not_enable_caching_with_pgbouncer() {
1081 let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap();
1082 assert_eq!(0, url.cache().capacity());
1083 }
1084
1085 #[test]
1086 fn should_parse_default_host() {
1087 let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap();
1088 assert_eq!("dbname", url.dbname());
1089 assert_eq!("localhost", url.host());
1090 }
1091
1092 #[test]
1093 fn should_handle_options_field() {
1094 let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap())
1095 .unwrap();
1096
1097 assert_eq!("--cluster=my_cluster", url.options().unwrap());
1098 }
1099
1100 #[tokio::test]
1101 async fn test_custom_search_path_pg() {
1102 async fn test_path(schema_name: &str) -> Option<String> {
1103 let mut url = Url::parse(&CONN_STR).unwrap();
1104 url.query_pairs_mut().append_pair("schema", schema_name);
1105
1106 let mut pg_url = PostgresUrl::new(url).unwrap();
1107 pg_url.set_flavour(PostgresFlavour::Postgres);
1108
1109 let client = PostgreSql::new(pg_url).await.unwrap();
1110
1111 let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
1112 let row = result_set.first().unwrap();
1113
1114 row[0].to_string()
1115 }
1116
1117 assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\""));
1119 assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\""));
1120 assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
1121 assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\""));
1122 assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
1123 assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
1124 assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));
1125
1126 assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
1128 assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
1129 assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
1130 assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
1131 assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
1132 assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
1133 assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
1134 assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
1135 assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
1136 assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
1137 assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
1138 assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
1139 assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
1140 assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));
1141
1142 for ident in RESERVED_KEYWORDS {
1143 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1144 }
1145
1146 for ident in RESERVED_TYPE_FUNCTION_NAMES {
1147 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1148 }
1149 }
1150
1151 #[tokio::test]
1152 async fn test_custom_search_path_pg_pgbouncer() {
1153 async fn test_path(schema_name: &str) -> Option<String> {
1154 let mut url = Url::parse(&CONN_STR).unwrap();
1155 url.query_pairs_mut().append_pair("schema", schema_name);
1156 url.query_pairs_mut().append_pair("pbbouncer", "true");
1157
1158 let mut pg_url = PostgresUrl::new(url).unwrap();
1159 pg_url.set_flavour(PostgresFlavour::Postgres);
1160
1161 let client = PostgreSql::new(pg_url).await.unwrap();
1162
1163 let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
1164 let row = result_set.first().unwrap();
1165
1166 row[0].to_string()
1167 }
1168
1169 assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\""));
1171 assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\""));
1172 assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
1173 assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\""));
1174 assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
1175 assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
1176 assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));
1177
1178 assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
1180 assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
1181 assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
1182 assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
1183 assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
1184 assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
1185 assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
1186 assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
1187 assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
1188 assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
1189 assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
1190 assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
1191 assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
1192 assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));
1193
1194 for ident in RESERVED_KEYWORDS {
1195 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1196 }
1197
1198 for ident in RESERVED_TYPE_FUNCTION_NAMES {
1199 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1200 }
1201 }
1202
1203 #[tokio::test]
1204 async fn test_custom_search_path_crdb() {
1205 async fn test_path(schema_name: &str) -> Option<String> {
1206 let mut url = Url::parse(&CRDB_CONN_STR).unwrap();
1207 url.query_pairs_mut().append_pair("schema", schema_name);
1208
1209 let mut pg_url = PostgresUrl::new(url).unwrap();
1210 pg_url.set_flavour(PostgresFlavour::Cockroach);
1211
1212 let client = PostgreSql::new(pg_url).await.unwrap();
1213
1214 let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
1215 let row = result_set.first().unwrap();
1216
1217 row[0].to_string()
1218 }
1219
1220 assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
1222 assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
1223 assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra"));
1224 assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
1225 assert_eq!(test_path("héllo").await.as_deref(), Some("héllo"));
1226 assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$"));
1227 assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$"));
1228
1229 assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
1231 assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
1232 assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
1233 assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
1234 assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
1235 assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
1236 assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
1237 assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
1238 assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
1239 assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
1240 assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
1241 assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
1242 assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
1243 assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));
1244
1245 for ident in RESERVED_KEYWORDS {
1246 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1247 }
1248
1249 for ident in RESERVED_TYPE_FUNCTION_NAMES {
1250 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1251 }
1252 }
1253
1254 #[tokio::test]
1255 async fn test_custom_search_path_unknown_pg() {
1256 async fn test_path(schema_name: &str) -> Option<String> {
1257 let mut url = Url::parse(&CONN_STR).unwrap();
1258 url.query_pairs_mut().append_pair("schema", schema_name);
1259
1260 let mut pg_url = PostgresUrl::new(url).unwrap();
1261 pg_url.set_flavour(PostgresFlavour::Unknown);
1262
1263 let client = PostgreSql::new(pg_url).await.unwrap();
1264
1265 let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
1266 let row = result_set.first().unwrap();
1267
1268 row[0].to_string()
1269 }
1270
1271 assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
1273 assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
1274 assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
1275 assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
1276 assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
1277 assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
1278 assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));
1279
1280 assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
1282 assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
1283 assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
1284 assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
1285 assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
1286 assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
1287 assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
1288 assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
1289 assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
1290 assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
1291 assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
1292 assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
1293 assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
1294 assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));
1295
1296 for ident in RESERVED_KEYWORDS {
1297 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1298 }
1299
1300 for ident in RESERVED_TYPE_FUNCTION_NAMES {
1301 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1302 }
1303 }
1304
1305 #[tokio::test]
1306 async fn test_custom_search_path_unknown_crdb() {
1307 async fn test_path(schema_name: &str) -> Option<String> {
1308 let mut url = Url::parse(&CONN_STR).unwrap();
1309 url.query_pairs_mut().append_pair("schema", schema_name);
1310
1311 let mut pg_url = PostgresUrl::new(url).unwrap();
1312 pg_url.set_flavour(PostgresFlavour::Unknown);
1313
1314 let client = PostgreSql::new(pg_url).await.unwrap();
1315
1316 let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap();
1317 let row = result_set.first().unwrap();
1318
1319 row[0].to_string()
1320 }
1321
1322 assert_eq!(test_path("hello").await.as_deref(), Some("hello"));
1324 assert_eq!(test_path("_hello").await.as_deref(), Some("_hello"));
1325 assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\""));
1326 assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0"));
1327 assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\""));
1328 assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\""));
1329 assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\""));
1330
1331 assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\""));
1333 assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\""));
1334 assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\""));
1335 assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\""));
1336 assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\""));
1337 assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\""));
1338 assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\""));
1339 assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\""));
1340 assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\""));
1341 assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\""));
1342 assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\""));
1343 assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\""));
1344 assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\""));
1345 assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\""));
1346
1347 for ident in RESERVED_KEYWORDS {
1348 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1349 }
1350
1351 for ident in RESERVED_TYPE_FUNCTION_NAMES {
1352 assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str()));
1353 }
1354 }
1355
1356 #[tokio::test]
1357 async fn should_map_nonexisting_database_error() {
1358 let mut url = Url::parse(&CONN_STR).unwrap();
1359 url.set_path("/this_does_not_exist");
1360
1361 let res = Sqlint::new(url.as_str()).await;
1362
1363 assert!(res.is_err());
1364
1365 match res {
1366 Ok(_) => unreachable!(),
1367 Err(e) => match e.kind() {
1368 ErrorKind::DatabaseDoesNotExist { db_name } => {
1369 assert_eq!(Some("3D000"), e.original_code());
1370 assert_eq!(Some("database \"this_does_not_exist\" does not exist"), e.original_message());
1371 assert_eq!(&Name::available("this_does_not_exist"), db_name)
1372 }
1373 kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind),
1374 },
1375 }
1376 }
1377
1378 #[tokio::test]
1379 async fn should_map_wrong_credentials_error() {
1380 let mut url = Url::parse(&CONN_STR).unwrap();
1381 url.set_username("WRONG").unwrap();
1382
1383 let res = Sqlint::new(url.as_str()).await;
1384 assert!(res.is_err());
1385
1386 let err = res.unwrap_err();
1387 assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG")));
1388 }
1389
1390 #[tokio::test]
1391 async fn should_map_tls_errors() {
1392 let mut url = Url::parse(&CONN_STR).expect("parsing url");
1393 url.set_query(Some("sslmode=require&sslaccept=strict"));
1394
1395 let res = Sqlint::new(url.as_str()).await;
1396
1397 assert!(res.is_err());
1398
1399 match res {
1400 Ok(_) => unreachable!(),
1401 Err(e) => match e.kind() {
1402 ErrorKind::TlsError { .. } => (),
1403 other => panic!("{:#?}", other),
1404 },
1405 }
1406 }
1407
1408 #[tokio::test]
1409 async fn should_map_incorrect_parameters_error() {
1410 let url = Url::parse(&CONN_STR).unwrap();
1411 let conn = Sqlint::new(url.as_str()).await.unwrap();
1412
1413 let res = conn.query_raw("SELECT $1", &[Value::integer(1), Value::integer(2)]).await;
1414
1415 assert!(res.is_err());
1416
1417 match res {
1418 Ok(_) => unreachable!(),
1419 Err(e) => match e.kind() {
1420 ErrorKind::IncorrectNumberOfParameters { expected, actual } => {
1421 assert_eq!(1, *expected);
1422 assert_eq!(2, *actual);
1423 }
1424 other => panic!("{:#?}", other),
1425 },
1426 }
1427 }
1428
/// Checks `is_safe_identifier` against identifiers that must be accepted,
/// identifiers that must be rejected, and both reserved-word lists.
///
/// Uses `assert!` rather than `assert_eq!(.., true/false)` and names the
/// offending identifier in the failure message.
#[test]
fn test_safe_ident() {
    // Identifiers expected to be accepted.
    let accepted = [
        "hello",
        "_hello",
        "àbracadabra",
        "h3ll0",
        "héllo",
        "héll0$",
        "héll_0$",
        "disconnect_security_must_honor_connect_scope_one2m",
    ];

    for ident in accepted {
        assert!(is_safe_identifier(ident), "`{}` should be a safe identifier", ident);
    }

    // Identifiers expected to be rejected (the empty string included).
    let rejected = [
        "",
        "Hello",
        "hEllo",
        "$hello",
        "hello!",
        "hello#",
        "he llo",
        " hello",
        "he-llo",
        "hÉllo",
        "1337",
        "_HELLO",
        "HELLO",
        "HELLO$",
        "ÀBRACADABRA",
    ];

    for ident in rejected {
        assert!(!is_safe_identifier(ident), "`{}` should not be a safe identifier", ident);
    }

    for ident in RESERVED_KEYWORDS {
        assert!(!is_safe_identifier(ident), "`{}` should not be a safe identifier", ident);
    }

    for ident in RESERVED_TYPE_FUNCTION_NAMES {
        assert!(!is_safe_identifier(ident), "`{}` should not be a safe identifier", ident);
    }
}
1466
/// With `pgbouncer=true` the generated config must carry no search path.
#[test]
fn search_path_pgbouncer_should_be_set_with_query() {
    let mut conn_url = Url::parse(&CONN_STR).unwrap();
    conn_url
        .query_pairs_mut()
        .append_pair("schema", "hello")
        .append_pair("pgbouncer", "true");

    let mut postgres_url = PostgresUrl::new(conn_url).unwrap();
    postgres_url.set_flavour(PostgresFlavour::Postgres);

    let search_path_is_unset = postgres_url.to_config().get_search_path().is_none();

    assert!(search_path_is_unset);
}
1483
/// For the Postgres flavour the schema must end up, double-quoted, in the
/// generated config's search path.
#[test]
fn search_path_pg_should_be_set_with_param() {
    let mut conn_url = Url::parse(&CONN_STR).unwrap();
    conn_url.query_pairs_mut().append_pair("schema", "hello");

    let mut postgres_url = PostgresUrl::new(conn_url).unwrap();
    postgres_url.set_flavour(PostgresFlavour::Postgres);

    let config = postgres_url.to_config();
    let search_path = config.get_search_path().map(String::as_str);

    assert_eq!(search_path, Some("\"hello\""));
}
1497
/// For the Cockroach flavour a safe schema name must end up, unquoted, in the
/// generated config's search path.
#[test]
fn search_path_crdb_safe_ident_should_be_set_with_param() {
    let mut conn_url = Url::parse(&CONN_STR).unwrap();
    conn_url.query_pairs_mut().append_pair("schema", "hello");

    let mut cockroach_url = PostgresUrl::new(conn_url).unwrap();
    cockroach_url.set_flavour(PostgresFlavour::Cockroach);

    let config = cockroach_url.to_config();
    let search_path = config.get_search_path().map(String::as_str);

    assert_eq!(search_path, Some("hello"));
}
1511
/// For the Cockroach flavour a schema name that is not a safe identifier must
/// leave the generated config without a search path.
#[test]
fn search_path_crdb_unsafe_ident_should_be_set_with_query() {
    let mut conn_url = Url::parse(&CONN_STR).unwrap();
    conn_url.query_pairs_mut().append_pair("schema", "HeLLo");

    let mut cockroach_url = PostgresUrl::new(conn_url).unwrap();
    cockroach_url.set_flavour(PostgresFlavour::Cockroach);

    let search_path_is_unset = cockroach_url.to_config().get_search_path().is_none();

    assert!(search_path_is_unset);
}
1525}