// yb_postgres/transaction.rs

1use crate::connection::ConnectionRef;
2use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
3use yb_tokio_postgres::types::{BorrowToSql, ToSql, Type};
4use yb_tokio_postgres::{Error, Row, SimpleQueryMessage};
5
6/// A representation of a PostgreSQL database transaction.
7///
8/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
9/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
10pub struct Transaction<'a> {
11    connection: ConnectionRef<'a>,
12    transaction: Option<yb_tokio_postgres::Transaction<'a>>,
13}
14
15impl<'a> Drop for Transaction<'a> {
16    fn drop(&mut self) {
17        if let Some(transaction) = self.transaction.take() {
18            let _ = self.connection.block_on(transaction.rollback());
19        }
20    }
21}
22
23impl<'a> Transaction<'a> {
24    pub(crate) fn new(
25        connection: ConnectionRef<'a>,
26        transaction: yb_tokio_postgres::Transaction<'a>,
27    ) -> Transaction<'a> {
28        Transaction {
29            connection,
30            transaction: Some(transaction),
31        }
32    }
33
34    /// Consumes the transaction, committing all changes made within it.
35    pub fn commit(mut self) -> Result<(), Error> {
36        self.connection
37            .block_on(self.transaction.take().unwrap().commit())
38    }
39
40    /// Rolls the transaction back, discarding all changes made within it.
41    ///
42    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
43    pub fn rollback(mut self) -> Result<(), Error> {
44        self.connection
45            .block_on(self.transaction.take().unwrap().rollback())
46    }
47
48    /// Like `Client::prepare`.
49    pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
50        self.connection
51            .block_on(self.transaction.as_ref().unwrap().prepare(query))
52    }
53
54    /// Like `Client::prepare_typed`.
55    pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
56        self.connection.block_on(
57            self.transaction
58                .as_ref()
59                .unwrap()
60                .prepare_typed(query, types),
61        )
62    }
63
64    /// Like `Client::execute`.
65    pub fn execute<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
66    where
67        T: ?Sized + ToStatement,
68    {
69        self.connection
70            .block_on(self.transaction.as_ref().unwrap().execute(query, params))
71    }
72
73    /// Like `Client::query`.
74    pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
75    where
76        T: ?Sized + ToStatement,
77    {
78        self.connection
79            .block_on(self.transaction.as_ref().unwrap().query(query, params))
80    }
81
82    /// Like `Client::query_one`.
83    pub fn query_one<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error>
84    where
85        T: ?Sized + ToStatement,
86    {
87        self.connection
88            .block_on(self.transaction.as_ref().unwrap().query_one(query, params))
89    }
90
91    /// Like `Client::query_opt`.
92    pub fn query_opt<T>(
93        &mut self,
94        query: &T,
95        params: &[&(dyn ToSql + Sync)],
96    ) -> Result<Option<Row>, Error>
97    where
98        T: ?Sized + ToStatement,
99    {
100        self.connection
101            .block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
102    }
103
104    /// Like `Client::query_raw`.
105    pub fn query_raw<T, P, I>(&mut self, query: &T, params: I) -> Result<RowIter<'_>, Error>
106    where
107        T: ?Sized + ToStatement,
108        P: BorrowToSql,
109        I: IntoIterator<Item = P>,
110        I::IntoIter: ExactSizeIterator,
111    {
112        let stream = self
113            .connection
114            .block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
115        Ok(RowIter::new(self.connection.as_ref(), stream))
116    }
117
118    /// Binds parameters to a statement, creating a "portal".
119    ///
120    /// Portals can be used with the `query_portal` method to page through the results of a query without being forced
121    /// to consume them all immediately.
122    ///
123    /// Portals are automatically closed when the transaction they were created in is closed.
124    ///
125    /// # Panics
126    ///
127    /// Panics if the number of parameters provided does not match the number expected.
128    pub fn bind<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Portal, Error>
129    where
130        T: ?Sized + ToStatement,
131    {
132        self.connection
133            .block_on(self.transaction.as_ref().unwrap().bind(query, params))
134    }
135
136    /// Continues execution of a portal, returning the next set of rows.
137    ///
138    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
139    /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
140    pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
141        self.connection.block_on(
142            self.transaction
143                .as_ref()
144                .unwrap()
145                .query_portal(portal, max_rows),
146        )
147    }
148
149    /// The maximally flexible version of `query_portal`.
150    pub fn query_portal_raw(
151        &mut self,
152        portal: &Portal,
153        max_rows: i32,
154    ) -> Result<RowIter<'_>, Error> {
155        let stream = self.connection.block_on(
156            self.transaction
157                .as_ref()
158                .unwrap()
159                .query_portal_raw(portal, max_rows),
160        )?;
161        Ok(RowIter::new(self.connection.as_ref(), stream))
162    }
163
164    /// Like `Client::copy_in`.
165    pub fn copy_in<T>(&mut self, query: &T) -> Result<CopyInWriter<'_>, Error>
166    where
167        T: ?Sized + ToStatement,
168    {
169        let sink = self
170            .connection
171            .block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
172        Ok(CopyInWriter::new(self.connection.as_ref(), sink))
173    }
174
175    /// Like `Client::copy_out`.
176    pub fn copy_out<T>(&mut self, query: &T) -> Result<CopyOutReader<'_>, Error>
177    where
178        T: ?Sized + ToStatement,
179    {
180        let stream = self
181            .connection
182            .block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
183        Ok(CopyOutReader::new(self.connection.as_ref(), stream))
184    }
185
186    /// Like `Client::simple_query`.
187    pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
188        self.connection
189            .block_on(self.transaction.as_ref().unwrap().simple_query(query))
190    }
191
192    /// Like `Client::batch_execute`.
193    pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
194        self.connection
195            .block_on(self.transaction.as_ref().unwrap().batch_execute(query))
196    }
197
198    /// Like `Client::cancel_token`.
199    pub fn cancel_token(&self) -> CancelToken {
200        CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
201    }
202
203    /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
204    pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
205        let transaction = self
206            .connection
207            .block_on(self.transaction.as_mut().unwrap().transaction())?;
208        Ok(Transaction::new(self.connection.as_ref(), transaction))
209    }
210
211    /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
212    pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
213    where
214        I: Into<String>,
215    {
216        let transaction = self
217            .connection
218            .block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
219        Ok(Transaction::new(self.connection.as_ref(), transaction))
220    }
221}