// tank_postgres/connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueWrap,
3    util::{
4        postgres_type_to_value, stream_postgres_row_to_tank_row,
5        stream_postgres_simple_query_message_to_tank_query_result, value_to_postgres_type,
6    },
7};
8use async_stream::try_stream;
9use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
10use postgres_openssl::MakeTlsConnector;
11use postgres_types::ToSql;
12use std::{
13    borrow::Cow,
14    env, mem,
15    path::PathBuf,
16    pin::{Pin, pin},
17    str::FromStr,
18};
19use tank_core::{
20    AsQuery, Connection, Driver, DynQuery, Entity, Error, ErrorContext, Executor, Query,
21    QueryResult, Result, RowsAffected, Transaction,
22    future::Either,
23    stream::{Stream, StreamExt, TryStreamExt},
24    truncate_long,
25};
26use tokio::{spawn, task::JoinHandle};
27use tokio_postgres::{NoTls, binary_copy::BinaryCopyInWriter};
28
29/// Connection wrapper for Postgres/Tokio Postgres client.
30///
31/// Manages the client handle and background task used to drive the connection
32/// and implements `Executor` for running queries against Postgres.
33#[derive(Debug)]
34pub struct PostgresConnection {
35    pub(crate) client: tokio_postgres::Client,
36    pub(crate) handle: JoinHandle<()>,
37    pub(crate) _transaction: bool,
38}
39
40impl Executor for PostgresConnection {
41    type Driver = PostgresDriver;
42
43    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
44        let sql = sql.as_str().trim_end().trim_end_matches(';');
45        Ok(
46            PostgresPrepared::new(self.client.prepare(&sql).await.map_err(|e| {
47                let error = Error::new(e).context(format!(
48                    "While preparing the query:\n{}",
49                    truncate_long!(sql)
50                ));
51                log::error!("{:#}", error);
52                error
53            })?)
54            .into(),
55        )
56    }
57
58    fn run<'s>(
59        &'s mut self,
60        query: impl AsQuery<Self::Driver> + 's,
61    ) -> impl Stream<Item = Result<QueryResult>> + Send {
62        let mut query = query.as_query();
63        let context = format!("While running the query:\n{}", query.as_mut());
64        let mut owned = mem::take(query.as_mut());
65        match owned {
66            Query::Raw(raw) => Either::Left(try_stream! {
67                let sql = &raw.sql;
68                {
69                    let stream = stream_postgres_simple_query_message_to_tank_query_result(
70                        async move || self.client.simple_query_raw(sql).await.map_err(Into::into),
71                    );
72                    let mut stream = pin!(stream);
73                    while let Some(value) = stream.next().await.transpose()? {
74                        yield value;
75                    }
76                }
77                *query.as_mut() = Query::Raw(raw);
78            }),
79            Query::Prepared(..) => Either::Right(try_stream! {
80                let mut transaction = self.begin().await?;
81                {
82                    let mut stream = pin!(transaction.run(&mut owned));
83                    while let Some(value) = stream.next().await.transpose()? {
84                        yield value;
85                    }
86                }
87                transaction.commit().await?;
88                *query.as_mut() = mem::take(&mut owned);
89            }),
90        }
91        .map_err(move |e: Error| {
92            let error = e.context(context.clone());
93            log::error!("{:#}", error);
94            error
95        })
96    }
97
98    fn fetch<'s>(
99        &'s mut self,
100        query: impl AsQuery<Self::Driver> + 's,
101    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send {
102        let mut query = query.as_query();
103        let context = format!("While fetching the query:\n{}", query.as_mut());
104        let owned = mem::take(query.as_mut());
105        stream_postgres_row_to_tank_row(async move || {
106            let row_stream = match owned {
107                Query::Raw(raw) => {
108                    let stream = self
109                        .client
110                        .query_raw(&raw.sql, Vec::<ValueWrap>::new())
111                        .await
112                        .map_err(|e| Error::new(e).context(context.clone()))?;
113                    *query.as_mut() = Query::Raw(raw);
114                    stream
115                }
116                Query::Prepared(mut prepared) => {
117                    let mut params = prepared.take_params();
118                    let types = prepared.statement.params();
119
120                    for (i, param) in params.iter_mut().enumerate() {
121                        *param = ValueWrap(Cow::Owned(
122                            mem::take(param)
123                                .take_value()
124                                .try_as(&postgres_type_to_value(&types[i]))?,
125                        ));
126                    }
127                    let stream = self
128                        .client
129                        .query_raw(&prepared.statement, params)
130                        .await
131                        .map_err(|e| Error::new(e).context(context.clone()))?;
132                    *query.as_mut() = Query::Prepared(prepared);
133                    stream
134                }
135            };
136            Ok(row_stream).map_err(|e| {
137                log::error!("{:#}", e);
138                e
139            })
140        })
141    }
142
143    async fn append<'a, E, It>(&mut self, entities: It) -> Result<RowsAffected>
144    where
145        E: Entity + 'a,
146        It: IntoIterator<Item = &'a E> + Send,
147        <It as IntoIterator>::IntoIter: Send,
148    {
149        let context = || format!("While appending to the table `{}`", E::table().full_name());
150        let mut result = RowsAffected {
151            rows_affected: Some(0),
152            last_affected_id: None,
153        };
154        let writer = self.driver().sql_writer();
155        let mut query = DynQuery::default();
156        writer.write_copy::<E>(&mut query);
157        let sink = self
158            .client
159            .copy_in(query.as_str())
160            .await
161            .with_context(context)?;
162        let types: Vec<_> = E::columns()
163            .into_iter()
164            .map(|c| value_to_postgres_type(&c.value))
165            .collect();
166        let writer = BinaryCopyInWriter::new(sink, &types);
167        let mut writer = pin!(writer);
168        let columns_len = E::columns().len();
169        let mut values = Vec::<ValueWrap>::with_capacity(columns_len);
170        let mut refs = Vec::<&(dyn ToSql + Sync)>::with_capacity(columns_len);
171        for entity in entities.into_iter() {
172            values.extend(
173                entity
174                    .row_full()
175                    .into_iter()
176                    .map(|v| ValueWrap(Cow::Owned(v))),
177            );
178            refs.extend(
179                values
180                    .iter()
181                    .map(|v| unsafe { &*(v as &(dyn ToSql + Sync) as *const _) }),
182            );
183            Pin::as_mut(&mut writer)
184                .write(&refs)
185                .await
186                .with_context(context)?;
187            refs.clear();
188            values.clear();
189            *result.rows_affected.as_mut().unwrap() += 1;
190        }
191        writer.finish().await.with_context(context)?;
192        Ok(result)
193    }
194}
195
196impl Connection for PostgresConnection {
197    async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
198        let context = format!("While trying to connect to `{}`", truncate_long!(url));
199        let mut url = Self::sanitize_url(url)?;
200        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
201            let value = url
202                .query_pairs()
203                .find_map(|(k, v)| if k == key { Some(v) } else { None })
204                .map(|v| v.to_string());
205            if remove && let Some(..) = value {
206                let mut result = url.clone();
207                result.set_query(None);
208                result
209                    .query_pairs_mut()
210                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
211                url = result;
212            };
213            value.or_else(|| env::var(env_var).ok().map(Into::into))
214        };
215        let sslmode = take_url_param("sslmode", "PGSSLMODE", false).unwrap_or("disable".into());
216        let (client, handle) = if sslmode == "disable" {
217            let (client, connection) = tokio_postgres::connect(url.as_str(), NoTls).await?;
218            let handle = spawn(async move {
219                if let Err(error) = connection.await
220                    && !error.is_closed()
221                {
222                    log::error!("Postgres connection error: {:#?}", error);
223                }
224            });
225            (client, handle)
226        } else {
227            let mut builder = SslConnector::builder(SslMethod::tls())?;
228            let path = PathBuf::from_str(
229                take_url_param("sslrootcert", "PGSSLROOTCERT", true)
230                    .as_deref()
231                    .unwrap_or("~/.postgresql/root.crt"),
232            )
233            .with_context(|| context.clone())?;
234            if path.exists() {
235                builder.set_ca_file(path)?;
236            }
237            let path = PathBuf::from_str(
238                take_url_param("sslcert", "PGSSLCERT", true)
239                    .as_deref()
240                    .unwrap_or("~/.postgresql/postgresql.crt"),
241            )
242            .with_context(|| context.clone())?;
243            if path.exists() {
244                builder.set_certificate_chain_file(path)?;
245            }
246            let path = PathBuf::from_str(
247                take_url_param("sslkey", "PGSSLKEY", true)
248                    .as_deref()
249                    .unwrap_or("~/.postgresql/postgresql.key"),
250            )
251            .with_context(|| context.clone())?;
252            if path.exists() {
253                builder.set_private_key_file(path, SslFiletype::PEM)?;
254            }
255            builder.set_verify(SslVerifyMode::PEER);
256            let connector = MakeTlsConnector::new(builder.build());
257            let (client, connection) = tokio_postgres::connect(url.as_str(), connector).await?;
258            let handle = spawn(async move {
259                if let Err(error) = connection.await
260                    && !error.is_closed()
261                {
262                    log::error!("Postgres connection error: {:#?}", error);
263                }
264            });
265            (client, handle)
266        };
267        Ok(Self {
268            client,
269            handle,
270            _transaction: false,
271        })
272    }
273
274    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
275        PostgresTransaction::new(self)
276    }
277
278    async fn disconnect(self) -> Result<()> {
279        drop(self.client);
280        if let Err(e) = self.handle.await {
281            let error = Error::new(e).context("While disconnecting from Postgres");
282            log::error!("{:#}", error);
283            return Err(error);
284        }
285        Ok(())
286    }
287}