tank_postgres/
connection.rs

1use crate::{
2    PostgresDriver, PostgresPrepared, PostgresTransaction, ValueHolder,
3    util::{
4        stream_postgres_row_to_tank_row, stream_postgres_simple_query_message_to_tank_query_result,
5    },
6};
7use async_stream::try_stream;
8use std::{borrow::Cow, pin::pin, sync::Arc};
9use tank_core::{
10    Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, Transaction,
11    future::Either,
12    printable_query,
13    stream::{Stream, StreamExt, TryStreamExt},
14};
15use tokio::spawn;
16use tokio_postgres::NoTls;
17
18pub struct PostgresConnection {
19    pub(crate) client: tokio_postgres::Client,
20    pub(crate) _transaction: bool,
21}
22
23impl Executor for PostgresConnection {
24    type Driver = PostgresDriver;
25
26    fn driver(&self) -> &Self::Driver {
27        &PostgresDriver {}
28    }
29
30    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
31        let sql = sql.trim_end().trim_end_matches(';');
32        Ok(PostgresPrepared::new(
33            self.client.prepare(&sql).await.with_context(|| {
34                format!("While preparing the query:\n{}", printable_query!(sql))
35            })?,
36        )
37        .into())
38    }
39
40    fn run(
41        &mut self,
42        query: Query<Self::Driver>,
43    ) -> impl Stream<Item = Result<QueryResult>> + Send {
44        let context = Arc::new(format!("While running the query:\n{}", query));
45        match query {
46            Query::Raw(sql) => Either::Left(
47                stream_postgres_simple_query_message_to_tank_query_result(async move || {
48                    self.client.simple_query_raw(&sql).await.map_err(Error::new)
49                })
50                .map_err(move |e| e.context(context.clone())),
51            ),
52            Query::Prepared(..) => Either::Right(try_stream! {
53                let mut transaction = self.begin().await?;
54                {
55                    let stream = transaction.run(query);
56                    let mut stream = pin!(stream);
57                    while let Some(value) = stream.next().await.transpose()? {
58                        yield value;
59                    }
60                }
61                transaction.commit().await?;
62            }),
63        }
64    }
65
66    fn fetch<'s>(
67        &'s mut self,
68        query: Query<Self::Driver>,
69    ) -> impl Stream<Item = Result<tank_core::RowLabeled>> + Send + 's {
70        let context = Arc::new(format!("While fetching the query:\n{}", query));
71        match query {
72            Query::Raw(sql) => Either::Left(stream_postgres_row_to_tank_row(async move || {
73                self.client
74                    .query_raw(&sql, Vec::<ValueHolder>::new())
75                    .await
76                    .map_err(Error::new)
77                    .context(context)
78            })),
79            Query::Prepared(..) => Either::Right(
80                try_stream! {
81                    let mut transaction = self.begin().await?;
82                    {
83                        let stream = transaction.fetch(query);
84                        let mut stream = pin!(stream);
85                        while let Some(value) = stream.next().await.transpose()? {
86                            yield value;
87                        }
88                    }
89                    transaction.commit().await?;
90                }
91                .map_err(move |e: Error| e.context(context.clone())),
92            ),
93        }
94    }
95}
96
97impl Connection for PostgresConnection {
98    #[allow(refining_impl_trait)]
99    async fn connect(url: Cow<'static, str>) -> Result<PostgresConnection> {
100        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
101        if !url.starts_with(&prefix) {
102            let error = Error::msg(format!(
103                "Postgres connection url must start with `{}`",
104                &prefix
105            ));
106            log::error!("{:#}", error);
107            return Err(error);
108        }
109        let (client, connection) = tokio_postgres::connect(&url, NoTls)
110            .await
111            .with_context(|| format!("While trying to connect to `{}`", url))?;
112        spawn(async move {
113            if let Err(e) = connection.await {
114                log::error!("Postgres connection error: {:#}", e);
115            }
116        });
117
118        Ok(Self {
119            client,
120            _transaction: false,
121        })
122    }
123
124    #[allow(refining_impl_trait)]
125    fn begin(&mut self) -> impl Future<Output = Result<PostgresTransaction<'_>>> {
126        PostgresTransaction::new(self)
127    }
128}