simple_pg_client/
transaction.rs

1use crate::codec::FrontendMessage;
2use crate::connection::RequestMessages;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{
12    bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
13    SimpleQueryMessage, Statement, ToStatement,
14};
15use bytes::Buf;
16use futures_util::TryStreamExt;
17use postgres_protocol::message::frontend;
18use tokio::io::{AsyncRead, AsyncWrite};
19
20/// A representation of a PostgreSQL database transaction.
21///
22/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
23/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
24pub struct Transaction<'a> {
25    #[doc(hidden)]
26    pub client: &'a mut Client,
27    #[doc(hidden)]
28    pub returning_transaction_depth: u16,
29    #[doc(hidden)]
30    pub done: bool,
31}
32
/// A representation of a PostgreSQL database savepoint.
///
/// Holds only the nesting depth of the savepoint. Presumably this corresponds to the
/// `sp_<depth>` names that `Transaction` uses (not referenced in this file; verify at use sites).
pub struct Savepoint {
    // Nesting depth of this savepoint.
    #[doc(hidden)]
    pub depth: u16,
}
38
39impl<'a> Drop for Transaction<'a> {
40    fn drop(&mut self) {
41        if self.done {
42            return;
43        }
44
45        let query = if self.returning_transaction_depth > 0 {
46            format!("ROLLBACK TO sp_{}", self.returning_transaction_depth)
47        } else {
48            "ROLLBACK".to_string()
49        };
50
51        let buf = self.client.inner().with_buf(|buf| {
52            frontend::query(&query, buf).unwrap();
53            buf.split().freeze()
54        });
55        let _ = self
56            .client
57            .inner()
58            .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
59        self.client.transaction_depth = self
60            .client
61            .transaction_depth
62            .min(self.returning_transaction_depth);
63    }
64}
65
66impl<'a> Transaction<'a> {
67    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
68        Transaction {
69            returning_transaction_depth: client.transaction_depth.saturating_sub(1),
70            client,
71            done: false,
72        }
73    }
74
75    /// Consumes the transaction, committing all changes made within it.
76    pub async fn commit(mut self) -> Result<(), Error> {
77        self.done = true;
78        let query = if self.returning_transaction_depth > 0 {
79            format!("RELEASE sp_{}", self.returning_transaction_depth)
80        } else {
81            "COMMIT".to_string()
82        };
83        self.client.batch_execute(&query).await
84    }
85
86    /// Rolls the transaction back, discarding all changes made within it.
87    ///
88    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
89    pub async fn rollback(mut self) -> Result<(), Error> {
90        self.done = true;
91        let query = if self.returning_transaction_depth > 0 {
92            format!("ROLLBACK TO sp_{}", self.returning_transaction_depth)
93        } else {
94            "ROLLBACK".to_string()
95        };
96        self.client.batch_execute(&query).await?;
97        self.client.transaction_depth = self.returning_transaction_depth;
98        Ok(())
99    }
100
101    /// Like `Client::prepare`.
102    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
103        self.client.prepare(query).await
104    }
105
106    /// Like `Client::prepare_typed`.
107    pub async fn prepare_typed(
108        &self,
109        query: &str,
110        parameter_types: &[Type],
111    ) -> Result<Statement, Error> {
112        self.client.prepare_typed(query, parameter_types).await
113    }
114
115    /// Like `Client::query`.
116    pub async fn query<T>(
117        &self,
118        statement: &T,
119        params: &[&(dyn ToSql + Sync)],
120    ) -> Result<Vec<Row>, Error>
121    where
122        T: ?Sized + ToStatement,
123    {
124        self.client.query(statement, params).await
125    }
126
127    /// Like `Client::query_one`.
128    pub async fn query_one<T>(
129        &self,
130        statement: &T,
131        params: &[&(dyn ToSql + Sync)],
132    ) -> Result<Row, Error>
133    where
134        T: ?Sized + ToStatement,
135    {
136        self.client.query_one(statement, params).await
137    }
138
139    /// Like `Client::query_opt`.
140    pub async fn query_opt<T>(
141        &self,
142        statement: &T,
143        params: &[&(dyn ToSql + Sync)],
144    ) -> Result<Option<Row>, Error>
145    where
146        T: ?Sized + ToStatement,
147    {
148        self.client.query_opt(statement, params).await
149    }
150
151    /// Like `Client::query_raw`.
152    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
153    where
154        T: ?Sized + ToStatement,
155        P: BorrowToSql,
156        I: IntoIterator<Item = P>,
157        I::IntoIter: ExactSizeIterator,
158    {
159        self.client.query_raw(statement, params).await
160    }
161
162    /// Like `Client::execute`.
163    pub async fn execute<T>(
164        &self,
165        statement: &T,
166        params: &[&(dyn ToSql + Sync)],
167    ) -> Result<u64, Error>
168    where
169        T: ?Sized + ToStatement,
170    {
171        self.client.execute(statement, params).await
172    }
173
174    /// Like `Client::execute_iter`.
175    pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
176    where
177        T: ?Sized + ToStatement,
178        P: BorrowToSql,
179        I: IntoIterator<Item = P>,
180        I::IntoIter: ExactSizeIterator,
181    {
182        self.client.execute_raw(statement, params).await
183    }
184
185    /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
186    ///
187    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
188    /// connection that created them.
189    ///
190    /// # Panics
191    ///
192    /// Panics if the number of parameters provided does not match the number expected.
193    pub async fn bind<T>(
194        &self,
195        statement: &T,
196        params: &[&(dyn ToSql + Sync)],
197    ) -> Result<Portal, Error>
198    where
199        T: ?Sized + ToStatement,
200    {
201        self.bind_raw(statement, slice_iter(params)).await
202    }
203
204    /// A maximally flexible version of [`bind`].
205    ///
206    /// [`bind`]: #method.bind
207    pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
208    where
209        T: ?Sized + ToStatement,
210        P: BorrowToSql,
211        I: IntoIterator<Item = P>,
212        I::IntoIter: ExactSizeIterator,
213    {
214        let statement = statement.__convert().into_statement(self.client).await?;
215        bind::bind(self.client.inner(), statement, params).await
216    }
217
218    /// Continues execution of a portal, returning a stream of the resulting rows.
219    ///
220    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
221    /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
222    pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
223        self.query_portal_raw(portal, max_rows)
224            .await?
225            .try_collect()
226            .await
227    }
228
229    /// The maximally flexible version of [`query_portal`].
230    ///
231    /// [`query_portal`]: #method.query_portal
232    pub async fn query_portal_raw(
233        &self,
234        portal: &Portal,
235        max_rows: i32,
236    ) -> Result<RowStream, Error> {
237        query::query_portal(self.client.inner(), portal, max_rows).await
238    }
239
240    /// Like `Client::copy_in`.
241    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
242    where
243        T: ?Sized + ToStatement,
244        U: Buf + 'static + Send,
245    {
246        self.client.copy_in(statement).await
247    }
248
249    /// Like `Client::copy_out`.
250    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
251    where
252        T: ?Sized + ToStatement,
253    {
254        self.client.copy_out(statement).await
255    }
256
257    /// Like `Client::simple_query`.
258    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
259        self.client.simple_query(query).await
260    }
261
262    /// Like `Client::batch_execute`.
263    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
264        self.client.batch_execute(query).await
265    }
266
267    /// Like `Client::cancel_token`.
268    pub fn cancel_token(&self) -> CancelToken {
269        self.client.cancel_token()
270    }
271
272    /// Like `Client::cancel_query`.
273    #[cfg(feature = "runtime")]
274    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
275    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
276    where
277        T: MakeTlsConnect<Socket>,
278    {
279        #[allow(deprecated)]
280        self.client.cancel_query(tls).await
281    }
282
283    /// Like `Client::cancel_query_raw`.
284    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
285    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
286    where
287        S: AsyncRead + AsyncWrite + Unpin,
288        T: TlsConnect<S>,
289    {
290        #[allow(deprecated)]
291        self.client.cancel_query_raw(stream, tls).await
292    }
293
294    /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
295    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
296        self._savepoint().await
297    }
298
299    /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
300    pub async fn savepoint<I>(&mut self, _name: I) -> Result<Transaction<'_>, Error>
301    where
302        I: Into<String>,
303    {
304        self._savepoint().await
305    }
306
307    async fn _savepoint(&mut self) -> Result<Transaction<'_>, Error> {
308        self.client.transaction().await
309    }
310
311    /// Returns a reference to the underlying `Client`.
312    pub fn client(&self) -> &Client {
313        self.client
314    }
315}