1use crate::{
2 PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3 util::{
4 postgres_type_to_value, stream_postgres_row_to_tank_row,
5 stream_postgres_simple_query_message_to_tank_query_result, value_to_postgres_type,
6 },
7};
8use async_stream::try_stream;
9use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
10use postgres_openssl::MakeTlsConnector;
11use postgres_types::ToSql;
12use std::{
13 borrow::Cow,
14 env, mem,
15 path::PathBuf,
16 pin::{Pin, pin},
17 str::FromStr,
18};
19use tank_core::{
20 AsQuery, Connection, Driver, Entity, Error, ErrorContext, Executor, Query, QueryResult, Result,
21 RowsAffected, Transaction,
22 future::Either,
23 stream::{Stream, StreamExt, TryStreamExt},
24 truncate_long,
25};
26use tokio::{spawn, task::JoinHandle};
27use tokio_postgres::{NoTls, binary_copy::BinaryCopyInWriter};
28
/// A connection to a Postgres server.
///
/// Holds the `tokio_postgres` client used to issue queries together with the handle of the
/// task spawned by `connect` to drive the underlying connection.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Client through which all queries, prepares and `COPY` operations are sent.
    pub(crate) client: tokio_postgres::Client,
    /// Handle of the spawned task driving the connection; awaited by `disconnect` after the
    /// client has been dropped.
    pub(crate) handle: JoinHandle<()>,
    /// Initialised to `false` by `connect`. NOTE(review): not read anywhere in this file —
    /// presumably marks a connection used by a transaction; confirm against
    /// `PostgresTransaction`.
    pub(crate) _transaction: bool,
}
39
40impl Executor for PostgresConnection {
41 type Driver = PostgresDriver;
42
43 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
44 let sql = sql.trim_end().trim_end_matches(';');
45 Ok(
46 PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
47 let error = Error::new(e).context(format!(
48 "While preparing the query:\n{}",
49 truncate_long!(sql)
50 ));
51 log::error!("{:#}", error);
52 error
53 })?)
54 .into(),
55 )
56 }
57
58 fn run<'s>(
59 &'s mut self,
60 query: impl AsQuery<Self::Driver> + 's,
61 ) -> impl Stream<Item = Result<QueryResult>> + Send {
62 let mut query = query.as_query();
63 let context = format!("While running the query:\n{}", query.as_mut());
64 let mut owned = mem::take(query.as_mut());
65 match owned {
66 Query::Raw(sql) => {
67 Either::Left(stream_postgres_simple_query_message_to_tank_query_result(
68 async move || self.client.simple_query_raw(&sql).await.map_err(Into::into),
69 ))
70 }
71 Query::Prepared(..) => Either::Right(try_stream! {
72 let mut transaction = self.begin().await?;
73 {
74 let mut stream = pin!(transaction.run(&mut owned));
75 while let Some(value) = stream.next().await.transpose()? {
76 yield value;
77 }
78 }
79 transaction.commit().await?;
80 *query.as_mut() = mem::take(&mut owned);
81 }),
82 }
83 .map_err(move |e: Error| {
84 let error = e.context(context.clone());
85 log::error!("{:#}", error);
86 error
87 })
88 }
89
90 fn fetch<'s>(
91 &'s mut self,
92 query: impl AsQuery<Self::Driver> + 's,
93 ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send {
94 let mut query = query.as_query();
95 let context = format!("While fetching the query:\n{}", query.as_mut());
96 let owned = mem::take(query.as_mut());
97 stream_postgres_row_to_tank_row(async move || {
98 let row_stream = match owned {
99 Query::Raw(mut sql) => {
100 let stream = self
101 .client
102 .query_raw(&sql, Vec::<ValueWrap>::new())
103 .await
104 .map_err(|e| Error::new(e).context(context.clone()))?;
105 *query.as_mut() = Query::Raw(mem::take(&mut sql));
106 stream
107 }
108 Query::Prepared(mut prepared) => {
109 let mut params = prepared.take_params();
110 let types = prepared.statement.params();
111
112 for (i, param) in params.iter_mut().enumerate() {
113 *param = ValueWrap(Cow::Owned(
114 mem::take(param)
115 .take_value()
116 .try_as(&postgres_type_to_value(&types[i]))?,
117 ));
118 }
119 let stream = self
120 .client
121 .query_raw(&prepared.statement, params)
122 .await
123 .map_err(|e| Error::new(e).context(context.clone()))?;
124 *query.as_mut() = Query::Prepared(prepared);
125 stream
126 }
127 };
128 Ok(row_stream).map_err(|e| {
129 log::error!("{:#}", e);
130 e
131 })
132 })
133 }
134
135 async fn append<'a, E, It>(&mut self, entities: It) -> Result<RowsAffected>
136 where
137 E: Entity + 'a,
138 It: IntoIterator<Item = &'a E> + Send,
139 <It as IntoIterator>::IntoIter: Send,
140 {
141 let context = || format!("While appending to the table `{}`", E::table().full_name());
142 let mut result = RowsAffected {
143 rows_affected: Some(0),
144 last_affected_id: None,
145 };
146 let writer = self.driver().sql_writer();
147 let mut sql = String::new();
148 writer.write_copy::<E>(&mut sql);
149 let sink = self.client.copy_in(&sql).await.with_context(context)?;
150 let types: Vec<_> = E::columns()
151 .into_iter()
152 .map(|c| value_to_postgres_type(&c.value))
153 .collect();
154 let writer = BinaryCopyInWriter::new(sink, &types);
155 let mut writer = pin!(writer);
156 let columns_len = E::columns().len();
157 let mut values = Vec::<ValueWrap>::with_capacity(columns_len);
158 let mut refs = Vec::<&(dyn ToSql + Sync)>::with_capacity(columns_len);
159 for entity in entities.into_iter() {
160 values.extend(
161 entity
162 .row_full()
163 .into_iter()
164 .map(|v| ValueWrap(Cow::Owned(v))),
165 );
166 refs.extend(
167 values
168 .iter()
169 .map(|v| unsafe { &*(v as &(dyn ToSql + Sync) as *const _) }),
170 );
171 Pin::as_mut(&mut writer)
172 .write(&refs)
173 .await
174 .with_context(context)?;
175 refs.clear();
176 values.clear();
177 *result.rows_affected.as_mut().unwrap() += 1;
178 }
179 writer.finish().await.with_context(context)?;
180 Ok(result)
181 }
182}
183
184impl Connection for PostgresConnection {
185 #[allow(refining_impl_trait)]
186 async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
187 let context = format!("While trying to connect to `{}`", truncate_long!(url));
188 let mut url = Self::sanitize_url(url)?;
189 let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
190 let value = url
191 .query_pairs()
192 .find_map(|(k, v)| if k == key { Some(v) } else { None })
193 .map(|v| v.to_string());
194 if remove && let Some(..) = value {
195 let mut result = url.clone();
196 result.set_query(None);
197 result
198 .query_pairs_mut()
199 .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
200 url = result;
201 };
202 value.or_else(|| env::var(env_var).ok().map(Into::into))
203 };
204 let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
205 let (client, handle) = if sslmode == "disable" {
206 let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
207 let handle = spawn(async move {
208 if let Err(error) = connection.await
209 && !error.is_closed()
210 {
211 log::error!("Postgres connection error: {:#?}", error);
212 }
213 });
214 (client, handle)
215 } else {
216 let mut builder = SslConnector::builder(SslMethod::tls())?;
217 let path = PathBuf::from_str(
218 take_url_param("sslrootcert", "PGSSLROOTCERT", true)
219 .as_deref()
220 .unwrap_or("~/.postgresql/root.crt"),
221 )
222 .with_context(|| context.clone())?;
223 if path.exists() {
224 builder.set_ca_file(path)?;
225 }
226 let path = PathBuf::from_str(
227 take_url_param("sslcert", "PGSSLCERT", true)
228 .as_deref()
229 .unwrap_or("~/.postgresql/postgresql.crt"),
230 )
231 .with_context(|| context.clone())?;
232 if path.exists() {
233 builder.set_certificate_chain_file(path)?;
234 }
235 let path = PathBuf::from_str(
236 take_url_param("sslkey", "PGSSLKEY", true)
237 .as_deref()
238 .unwrap_or("~/.postgresql/postgresql.key"),
239 )
240 .with_context(|| context.clone())?;
241 if path.exists() {
242 builder.set_private_key_file(path, SslFiletype::PEM)?;
243 }
244 builder.set_verify(SslVerifyMode::PEER);
245 let connector = MakeTlsConnector::new(builder.build());
246 let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
247 let handle = spawn(async move {
248 if let Err(error) = connection.await
249 && !error.is_closed()
250 {
251 log::error!("Postgres connection error: {:#?}", error);
252 }
253 });
254 (client, handle)
255 };
256 Ok(Self {
257 client,
258 handle,
259 _transaction: false,
260 })
261 }
262
263 #[allow(refining_impl_trait)]
264 fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
265 PostgresTransaction::new(self)
266 }
267
268 #[allow(refining_impl_trait)]
269 async fn disconnect(self) -> Result<()> {
270 drop(self.client);
271 if let Err(e) = self.handle.await {
272 let error = Error::new(e).context("While disconnecting from Postgres");
273 log::error!("{:#}", error);
274 return Err(error);
275 }
276 Ok(())
277 }
278}