tank_postgres/
transaction.rs

1use crate::{
2    PostgresConnection, PostgresDriver, PostgresPrepared, ValueHolder, util::row_to_tank_row,
3};
4use async_stream::try_stream;
5use std::pin::pin;
6use tank_core::{
7    Error, Executor, Query, QueryResult, Result, RowLabeled, Transaction,
8    future::{Either, TryFutureExt},
9    stream::{Stream, StreamExt},
10};
11
12pub struct PostgresTransaction<'c>(pub(crate) tokio_postgres::Transaction<'c>);
13
14impl<'c> PostgresTransaction<'c> {
15    pub async fn new(client: &'c mut PostgresConnection) -> Result<Self> {
16        Ok(Self(client.client.transaction().await?))
17    }
18}
19
20impl<'c> Executor for PostgresTransaction<'c> {
21    type Driver = PostgresDriver;
22    fn driver(&self) -> &Self::Driver {
23        &PostgresDriver {}
24    }
25    async fn prepare(&mut self, query: String) -> Result<Query<Self::Driver>> {
26        Ok(PostgresPrepared::new(self.0.prepare(&query).await?).into())
27    }
28    fn run(
29        &mut self,
30        query: Query<Self::Driver>,
31    ) -> impl Stream<Item = Result<QueryResult>> + Send {
32        try_stream! {
33            let stream = match query {
34                Query::Raw(sql) => {
35                    Either::Left(self.0.query_raw(&sql, Vec::<ValueHolder>::new()).await?)
36                }
37                Query::Prepared(mut prepared) => {
38                    let portal = if !prepared.is_complete() {
39                        prepared.complete(self).await?
40                    } else {
41                        prepared.get_portal().ok_or(Error::msg(format!(
42                            "The prepared statement `{}` is not complete",
43                            prepared
44                        )))?
45                    };
46                    Either::Right(self.0.query_portal_raw(&portal, 0).await?)
47                }
48            };
49            let mut stream = pin!(stream);
50            if let Some(first) = stream.next().await {
51                let labels = first?
52                    .columns()
53                    .iter()
54                    .map(|c| c.name().to_string())
55                    .collect::<tank_core::RowNames>();
56                while let Some(value) = stream.next().await {
57                    yield RowLabeled {
58                        labels: labels.clone(),
59                        values: row_to_tank_row(value?).into(),
60                    }
61                    .into()
62                }
63            }
64        }
65    }
66}
67
68impl<'c> Transaction<'c> for PostgresTransaction<'c> {
69    fn commit(self) -> impl Future<Output = Result<()>> {
70        self.0.commit().map_err(Into::into)
71    }
72    fn rollback(self) -> impl Future<Output = Result<()>> {
73        self.0.rollback().map_err(Into::into)
74    }
75}