1use crate::codec::FrontendMessage;
2use crate::connection::RequestMessages;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{
12 bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
13 SimpleQueryMessage, Statement, ToStatement,
14};
15use bytes::Buf;
16use futures_util::TryStreamExt;
17use postgres_protocol::message::frontend;
18use tokio::io::{AsyncRead, AsyncWrite};
19
/// A database transaction (or nested transaction) borrowed from a [`Client`].
///
/// When dropped without `commit` or `rollback` having been called, the `Drop`
/// implementation queues a rollback. A top-level transaction ends with
/// `COMMIT`/`ROLLBACK`; a nested one is scoped by a savepoint named `sp_N`,
/// where `N` is `returning_transaction_depth`.
pub struct Transaction<'a> {
    /// The client this transaction holds exclusively for its lifetime.
    #[doc(hidden)]
    pub client: &'a mut Client,
    /// Depth written back to `client.transaction_depth` when this transaction
    /// is rolled back or dropped. `0` selects plain `COMMIT`/`ROLLBACK`; any
    /// other value `N` selects `RELEASE sp_N` / `ROLLBACK TO sp_N`.
    #[doc(hidden)]
    pub returning_transaction_depth: u16,
    /// Set by `commit`/`rollback` before they run; when `true`, `Drop` does nothing.
    #[doc(hidden)]
    pub done: bool,
}
32
/// A savepoint descriptor holding a nesting depth.
///
/// NOTE(review): nothing in this file constructs or reads this type; presumably
/// `depth` corresponds to the `N` in the `sp_N` savepoint names used by
/// `Transaction` — confirm against callers.
pub struct Savepoint {
    /// Nesting depth associated with the savepoint.
    #[doc(hidden)]
    pub depth: u16,
}
38
39impl<'a> Drop for Transaction<'a> {
40 fn drop(&mut self) {
41 if self.done {
42 return;
43 }
44
45 let query = if self.returning_transaction_depth > 0 {
46 format!("ROLLBACK TO sp_{}", self.returning_transaction_depth)
47 } else {
48 "ROLLBACK".to_string()
49 };
50
51 let buf = self.client.inner().with_buf(|buf| {
52 frontend::query(&query, buf).unwrap();
53 buf.split().freeze()
54 });
55 let _ = self
56 .client
57 .inner()
58 .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
59 self.client.transaction_depth = self
60 .client
61 .transaction_depth
62 .min(self.returning_transaction_depth);
63 }
64}
65
66impl<'a> Transaction<'a> {
67 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
68 Transaction {
69 returning_transaction_depth: client.transaction_depth.saturating_sub(1),
70 client,
71 done: false,
72 }
73 }
74
75 pub async fn commit(mut self) -> Result<(), Error> {
77 self.done = true;
78 let query = if self.returning_transaction_depth > 0 {
79 format!("RELEASE sp_{}", self.returning_transaction_depth)
80 } else {
81 "COMMIT".to_string()
82 };
83 self.client.batch_execute(&query).await
84 }
85
86 pub async fn rollback(mut self) -> Result<(), Error> {
90 self.done = true;
91 let query = if self.returning_transaction_depth > 0 {
92 format!("ROLLBACK TO sp_{}", self.returning_transaction_depth)
93 } else {
94 "ROLLBACK".to_string()
95 };
96 self.client.batch_execute(&query).await?;
97 self.client.transaction_depth = self.returning_transaction_depth;
98 Ok(())
99 }
100
101 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
103 self.client.prepare(query).await
104 }
105
106 pub async fn prepare_typed(
108 &self,
109 query: &str,
110 parameter_types: &[Type],
111 ) -> Result<Statement, Error> {
112 self.client.prepare_typed(query, parameter_types).await
113 }
114
115 pub async fn query<T>(
117 &self,
118 statement: &T,
119 params: &[&(dyn ToSql + Sync)],
120 ) -> Result<Vec<Row>, Error>
121 where
122 T: ?Sized + ToStatement,
123 {
124 self.client.query(statement, params).await
125 }
126
127 pub async fn query_one<T>(
129 &self,
130 statement: &T,
131 params: &[&(dyn ToSql + Sync)],
132 ) -> Result<Row, Error>
133 where
134 T: ?Sized + ToStatement,
135 {
136 self.client.query_one(statement, params).await
137 }
138
139 pub async fn query_opt<T>(
141 &self,
142 statement: &T,
143 params: &[&(dyn ToSql + Sync)],
144 ) -> Result<Option<Row>, Error>
145 where
146 T: ?Sized + ToStatement,
147 {
148 self.client.query_opt(statement, params).await
149 }
150
151 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
153 where
154 T: ?Sized + ToStatement,
155 P: BorrowToSql,
156 I: IntoIterator<Item = P>,
157 I::IntoIter: ExactSizeIterator,
158 {
159 self.client.query_raw(statement, params).await
160 }
161
162 pub async fn execute<T>(
164 &self,
165 statement: &T,
166 params: &[&(dyn ToSql + Sync)],
167 ) -> Result<u64, Error>
168 where
169 T: ?Sized + ToStatement,
170 {
171 self.client.execute(statement, params).await
172 }
173
174 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
176 where
177 T: ?Sized + ToStatement,
178 P: BorrowToSql,
179 I: IntoIterator<Item = P>,
180 I::IntoIter: ExactSizeIterator,
181 {
182 self.client.execute_raw(statement, params).await
183 }
184
185 pub async fn bind<T>(
194 &self,
195 statement: &T,
196 params: &[&(dyn ToSql + Sync)],
197 ) -> Result<Portal, Error>
198 where
199 T: ?Sized + ToStatement,
200 {
201 self.bind_raw(statement, slice_iter(params)).await
202 }
203
204 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
208 where
209 T: ?Sized + ToStatement,
210 P: BorrowToSql,
211 I: IntoIterator<Item = P>,
212 I::IntoIter: ExactSizeIterator,
213 {
214 let statement = statement.__convert().into_statement(self.client).await?;
215 bind::bind(self.client.inner(), statement, params).await
216 }
217
218 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
223 self.query_portal_raw(portal, max_rows)
224 .await?
225 .try_collect()
226 .await
227 }
228
229 pub async fn query_portal_raw(
233 &self,
234 portal: &Portal,
235 max_rows: i32,
236 ) -> Result<RowStream, Error> {
237 query::query_portal(self.client.inner(), portal, max_rows).await
238 }
239
240 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
242 where
243 T: ?Sized + ToStatement,
244 U: Buf + 'static + Send,
245 {
246 self.client.copy_in(statement).await
247 }
248
249 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
251 where
252 T: ?Sized + ToStatement,
253 {
254 self.client.copy_out(statement).await
255 }
256
257 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
259 self.client.simple_query(query).await
260 }
261
262 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
264 self.client.batch_execute(query).await
265 }
266
267 pub fn cancel_token(&self) -> CancelToken {
269 self.client.cancel_token()
270 }
271
272 #[cfg(feature = "runtime")]
274 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
275 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
276 where
277 T: MakeTlsConnect<Socket>,
278 {
279 #[allow(deprecated)]
280 self.client.cancel_query(tls).await
281 }
282
283 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
285 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
286 where
287 S: AsyncRead + AsyncWrite + Unpin,
288 T: TlsConnect<S>,
289 {
290 #[allow(deprecated)]
291 self.client.cancel_query_raw(stream, tls).await
292 }
293
294 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
296 self._savepoint().await
297 }
298
299 pub async fn savepoint<I>(&mut self, _name: I) -> Result<Transaction<'_>, Error>
301 where
302 I: Into<String>,
303 {
304 self._savepoint().await
305 }
306
307 async fn _savepoint(&mut self) -> Result<Transaction<'_>, Error> {
308 self.client.transaction().await
309 }
310
311 pub fn client(&self) -> &Client {
313 self.client
314 }
315}