1use crate::{
2 PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3 util::{
4 stream_postgres_row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result,
5 },
6};
7use async_stream::try_stream;
8use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
9use postgres_openssl::MakeTlsConnector;
10use std::{borrow::Cow, env, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
11use tank_core::{
12 Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, Transaction,
13 future::Either,
14 stream::{Stream, StreamExt, TryStreamExt},
15 truncate_long,
16};
17use tokio::{spawn, task::JoinHandle};
18use tokio_postgres::NoTls;
19use url::Url;
20use urlencoding::decode;
21
22#[derive(Debug)]
23pub struct PostgresConnection {
24 pub(crate) client: tokio_postgres::Client,
25 pub(crate) handle: JoinHandle<()>,
26 pub(crate) _transaction: bool,
27}
28
29impl Executor for PostgresConnection {
30 type Driver = PostgresDriver;
31
32 fn driver(&self) -> &Self::Driver {
33 &PostgresDriver {}
34 }
35
36 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
37 let sql = sql.trim_end().trim_end_matches(';');
38 Ok(
39 PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
40 let e = Error::new(e).context(format!(
41 "While preparing the query:\n{}",
42 truncate_long!(sql)
43 ));
44 log::error!("{:#}", e);
45 e
46 })?)
47 .into(),
48 )
49 }
50
51 fn run(
52 &mut self,
53 query: Query<Self::Driver>,
54 ) -> impl Stream<Item = Result<QueryResult>> + Send {
55 let context = Arc::new(format!("While running the query:\n{}", query));
56 match query {
57 Query::Raw(sql) => {
58 Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
59 async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
60 ))
61 }
62 Query::Prepared(..) => Either::Right(try_stream! {
63 let mut transaction = self.begin().await?;
64 {
65 let mut stream = pin!(transaction.run(query));
66 while let Some(value) = stream.next().await.transpose()? {
67 yield value;
68 }
69 }
70 transaction.commit().await?;
71 }),
72 }
73 .map_err(move |e: Error| {
74 let e = e.context(context.clone());
75 log::error!("{:#}", e);
76 e
77 })
78 }
79
80 fn fetch<'s>(
81 &'s mut self,
82 query: Query<Self::Driver>,
83 ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
84 let context = Arc::new(format!("While fetching the query:\n{}", query));
85 match query {
86 Query::Raw(sql) => Either::Left(stream_postgres_row_to_tank_row(async move || {
87 self.client
88 .query_raw(&sql, Vec::<ValueWrap>::new())
89 .await
90 .map_err(|e| {
91 let e = Error::new(e).context(context.clone());
92 log::error!("{:#}", e);
93 e
94 })
95 })),
96 Query::Prepared(..) => Either::Right(
97 try_stream! {
98 let mut transaction = self.begin().await?;
99 {
100 let mut stream = pin!(transaction.fetch(query));
101 while let Some(value) = stream.next().await.transpose()? {
102 yield value;
103 }
104 }
105 transaction.commit().await?;
106 }
107 .map_err(move |e: Error| {
108 let e = e.context(context.clone());
109 log::error!("{:#}", e);
110 e
111 }),
112 ),
113 }
114 }
115}
116
117impl Connection for PostgresConnection {
118 #[allow(refining_impl_trait)]
119 async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
120 let context = || format!("While trying to connect to `{}`", url);
121 let url = decode(&url).with_context(context)?;
122 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
123 if !url.starts_with(&prefix) {
124 let error = Error::msg(format!(
125 "Postgres connection url must start with `{}`",
126 &prefix
127 ))
128 .context(context());
129 log::error!("{:#}", error);
130 return Err(error);
131 }
132 let mut url = Url::parse(&url).with_context(context)?;
133 let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
134 let value = url
135 .query_pairs()
136 .find_map(|(k, v)| if k == key { Some(v) } else { None })
137 .map(|v| v.to_string());
138 if remove && let Some(..) = value {
139 let mut result = url.clone();
140 result.set_query(None);
141 result
142 .query_pairs_mut()
143 .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
144 url = result;
145 };
146 value.or_else(|| env::var(env_var).ok().map(Into::into))
147 };
148 let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
149 let (client, handle) = if sslmode == "disable" {
150 let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
151 let handle = spawn(async move {
152 if let Err(e) = connection.await
153 && !e.is_closed()
154 {
155 log::error!("Postgres connection error: {:#}", e);
156 }
157 });
158 (client, handle)
159 } else {
160 let mut builder = SslConnector::builder(SslMethod::tls())?;
161 let path = PathBuf::from_str(
162 take_url_param("sslrootcert", "PGSSLROOTCERT", true)
163 .as_deref()
164 .unwrap_or("~/.postgresql/root.crt"),
165 )
166 .context(context())?;
167 if path.exists() {
168 builder.set_ca_file(path)?;
169 }
170 let path = PathBuf::from_str(
171 take_url_param("sslcert", "PGSSLCERT", true)
172 .as_deref()
173 .unwrap_or("~/.postgresql/postgresql.crt"),
174 )
175 .context(context())?;
176 if path.exists() {
177 builder.set_certificate_chain_file(path)?;
178 }
179 let path = PathBuf::from_str(
180 take_url_param("sslkey", "PGSSLKEY", true)
181 .as_deref()
182 .unwrap_or("~/.postgresql/postgresql.key"),
183 )
184 .context(context())?;
185 if path.exists() {
186 builder.set_private_key_file(path, SslFiletype::PEM)?;
187 }
188 builder.set_verify(SslVerifyMode::PEER);
189 let connector = MakeTlsConnector::new(builder.build());
190 let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
191 let handle = spawn(async move {
192 if let Err(e) = connection.await
193 && !e.is_closed()
194 {
195 log::error!("Postgres connection error: {:#}", e);
196 }
197 });
198 (client, handle)
199 };
200 Ok(Self {
201 client,
202 handle,
203 _transaction: false,
204 })
205 }
206
207 #[allow(refining_impl_trait)]
208 fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
209 PostgresTransaction::new(self)
210 }
211
212 #[allow(refining_impl_trait)]
213 async fn disconnect(self) -> Result<()> {
214 drop(self.client);
215 if let Err(e) = self.handle.await {
216 let e = Error::new(e).context("While disconnecting from Postgres");
217 log::error!("{:#}", e);
218 return Err(e);
219 }
220 Ok(())
221 }
222}