1use crate::{
2 PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3 util::{row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result},
4};
5use async_stream::try_stream;
6use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
7use postgres_openssl::MakeTlsConnector;
8use std::{borrow::Cow, env, mem, path::PathBuf, pin::pin, str::FromStr, sync::Arc};
9use tank_core::{
10 AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
11 Transaction,
12 future::Either,
13 stream::{Stream, StreamExt, TryStreamExt},
14 truncate_long,
15};
16use tokio::{spawn, task::JoinHandle};
17use tokio_postgres::NoTls;
18use url::Url;
19use urlencoding::decode;
20
/// A single Postgres connection backed by `tokio_postgres`.
///
/// Created by [`Connection::connect`]. Statements are issued through `client`,
/// while the connection itself is driven by a spawned background task whose
/// handle is kept in `handle`.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Client used to prepare and execute statements.
    pub(crate) client: tokio_postgres::Client,
    /// Background task driving the underlying connection; awaited by `disconnect`.
    pub(crate) handle: JoinHandle<()>,
    /// Set to `false` by `connect`; not read in this file.
    /// NOTE(review): presumably marks connections owned by a transaction — confirm.
    pub(crate) _transaction: bool,
}
27
28impl Executor for PostgresConnection {
29 type Driver = PostgresDriver;
30
31 fn driver(&self) -> &Self::Driver {
32 &PostgresDriver {}
33 }
34
35 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
36 let sql = sql.trim_end().trim_end_matches(';');
37 Ok(
38 PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
39 let error = Error::new(e).context(format!(
40 "While preparing the query:\n{}",
41 truncate_long!(sql)
42 ));
43 log::error!("{:#}", error);
44 error
45 })?)
46 .into(),
47 )
48 }
49
50 fn run<'s>(
51 &'s mut self,
52 query: impl AsQuery<Self::Driver> + 's,
53 ) -> impl Stream<Item = Result<QueryResult>> + Send {
54 let mut query = query.as_query();
55 let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
56 let mut owned = mem::take(query.as_mut());
57 match owned {
58 Query::Raw(sql) => {
59 Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
60 async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
61 ))
62 }
63 Query::Prepared(..) => Either::Right(try_stream! {
64 let mut transaction = self.begin().await?;
65 {
66 let mut stream = pin!(transaction.run(&mut owned));
67 while let Some(value) = stream.next().await.transpose()? {
68 yield value;
69 }
70 }
71 transaction.commit().await?;
72 *query.as_mut() = mem::take(&mut owned);
73 }),
74 }
75 .map_err(move |e: Error| {
76 let error = e.context(context.clone());
77 log::error!("{:#}", error);
78 error
79 })
80 }
81
82 fn fetch<'s>(
83 &'s mut self,
84 query: impl AsQuery<Self::Driver> + 's,
85 ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
86 let mut query = query.as_query();
87 let context = Arc::new(format!("While fetching the query:\n{}", query.as_mut()));
88 let owned = mem::take(query.as_mut());
89 match owned {
90 Query::Raw(sql) => {
91 let stream = async move || {
92 self.client
93 .query_raw(&sql, Vec::<ValueWrap>::new())
94 .await
95 .map_err(|e| {
96 let error = Error::new(e).context(context.clone());
97 log::error!("{:#}", error);
98 error
99 })
100 };
101 let stream = try_stream! {
102 let stream = stream().await?;
103 let mut stream = pin!(stream);
104 let mut labels: Option<tank_core::RowNames> = None;
105 while let Some(row) = stream.next().await.transpose()? {
106 let labels = labels.get_or_insert_with(|| {
107 row.columns().iter().map(|c| c.name().to_string()).collect()
108 });
109 yield tank_core::RowLabeled {
110 labels: labels.clone(),
111 values: row_to_tank_row(row)?.into(),
112 };
113 }
114 };
115 Either::Left(stream)
116 }
117 Query::Prepared(prepared) => Either::Right(
118 try_stream! {
119 let mut transaction = self.begin().await?;
120 {
121 let mut query = Query::Prepared(prepared);
122 let mut stream = pin!(transaction.fetch(&mut query));
123 while let Some(value) = stream.next().await.transpose()? {
124 yield value;
125 }
126 }
127 transaction.commit().await?;
128 }
129 .map_err(move |e: Error| {
130 let error = e.context(context.clone());
131 log::error!("{:#}", error);
132 error
133 }),
134 ),
135 }
136 }
137}
138
139impl Connection for PostgresConnection {
140 #[allow(refining_impl_trait)]
141 async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
142 let context = || format!("While trying to connect to `{}`", truncate_long!(url));
143 let url = decode(&url).with_context(context)?;
144 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
145 if !url.starts_with(&prefix) {
146 let error = Error::msg(format!(
147 "Postgres connection url must start with `{}`",
148 &prefix
149 ))
150 .context(context());
151 log::error!("{:#}", error);
152 return Err(error);
153 }
154 let mut url = Url::parse(&url).with_context(context)?;
155 let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
156 let value = url
157 .query_pairs()
158 .find_map(|(k, v)| if k == key { Some(v) } else { None })
159 .map(|v| v.to_string());
160 if remove && let Some(..) = value {
161 let mut result = url.clone();
162 result.set_query(None);
163 result
164 .query_pairs_mut()
165 .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
166 url = result;
167 };
168 value.or_else(|| env::var(env_var).ok().map(Into::into))
169 };
170 let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
171 let (client, handle) = if sslmode == "disable" {
172 let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
173 let handle = spawn(async move {
174 if let Err(error) = connection.await
175 && !error.is_closed()
176 {
177 log::error!("Postgres connection error: {:#?}", error);
178 }
179 });
180 (client, handle)
181 } else {
182 let mut builder = SslConnector::builder(SslMethod::tls())?;
183 let path = PathBuf::from_str(
184 take_url_param("sslrootcert", "PGSSLROOTCERT", true)
185 .as_deref()
186 .unwrap_or("~/.postgresql/root.crt"),
187 )
188 .context(context())?;
189 if path.exists() {
190 builder.set_ca_file(path)?;
191 }
192 let path = PathBuf::from_str(
193 take_url_param("sslcert", "PGSSLCERT", true)
194 .as_deref()
195 .unwrap_or("~/.postgresql/postgresql.crt"),
196 )
197 .context(context())?;
198 if path.exists() {
199 builder.set_certificate_chain_file(path)?;
200 }
201 let path = PathBuf::from_str(
202 take_url_param("sslkey", "PGSSLKEY", true)
203 .as_deref()
204 .unwrap_or("~/.postgresql/postgresql.key"),
205 )
206 .context(context())?;
207 if path.exists() {
208 builder.set_private_key_file(path, SslFiletype::PEM)?;
209 }
210 builder.set_verify(SslVerifyMode::PEER);
211 let connector = MakeTlsConnector::new(builder.build());
212 let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
213 let handle = spawn(async move {
214 if let Err(error) = connection.await
215 && !error.is_closed()
216 {
217 log::error!("Postgres connection error: {:#?}", error);
218 }
219 });
220 (client, handle)
221 };
222 Ok(Self {
223 client,
224 handle,
225 _transaction: false,
226 })
227 }
228
229 #[allow(refining_impl_trait)]
230 fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
231 PostgresTransaction::new(self)
232 }
233
234 #[allow(refining_impl_trait)]
235 async fn disconnect(self) -> Result<()> {
236 drop(self.client);
237 if let Err(e) = self.handle.await {
238 let error = Error::new(e).context("While disconnecting from Postgres");
239 log::error!("{:#}", error);
240 return Err(error);
241 }
242 Ok(())
243 }
244}