// tank_postgres/transaction.rs

1use crate::{
2    PostgresConnection, PostgresDriver, PostgresPrepared, ValueHolder,
3    util::stream_postgres_row_to_tank_row,
4};
5use tank_core::{
6    Error, Executor, Query, QueryResult, Result, Transaction,
7    future::{Either, TryFutureExt},
8    stream::Stream,
9};
10
11pub struct PostgresTransaction<'c>(pub(crate) tokio_postgres::Transaction<'c>);
12
13impl<'c> PostgresTransaction<'c> {
14    pub async fn new(client: &'c mut PostgresConnection) -> Result<Self> {
15        Ok(Self(client.client.transaction().await?))
16    }
17}
18
19impl<'c> Executor for PostgresTransaction<'c> {
20    type Driver = PostgresDriver;
21    fn driver(&self) -> &Self::Driver {
22        &PostgresDriver {}
23    }
24    async fn prepare(&mut self, query: String) -> Result<Query<Self::Driver>> {
25        Ok(PostgresPrepared::new(self.0.prepare(&query).await?).into())
26    }
27    fn run(
28        &mut self,
29        query: Query<Self::Driver>,
30    ) -> impl Stream<Item = Result<QueryResult>> + Send {
31        stream_postgres_row_to_tank_row(async move || match query {
32            Query::Raw(sql) => Ok(Either::Left(
33                self.0.query_raw(&sql, Vec::<ValueHolder>::new()).await?,
34            )),
35            Query::Prepared(mut prepared) => {
36                let portal = if !prepared.is_complete() {
37                    prepared.complete(self).await?
38                } else {
39                    prepared.get_portal().ok_or(Error::msg(format!(
40                        "The prepared statement `{}` is not complete",
41                        prepared
42                    )))?
43                };
44                Ok(Either::Right(self.0.query_portal_raw(&portal, 0).await?))
45            }
46        })
47    }
48}
49
50impl<'c> Transaction<'c> for PostgresTransaction<'c> {
51    fn commit(self) -> impl Future<Output = Result<()>> {
52        self.0.commit().map_err(Into::into)
53    }
54    fn rollback(self) -> impl Future<Output = Result<()>> {
55        self.0.rollback().map_err(Into::into)
56    }
57}