1use crate::codec::FrontendMessage;
2use crate::connection::RequestMessages;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{
12 bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
13 SimpleQueryMessage, Statement, ToStatement,
14};
15use bytes::Buf;
16use futures_util::TryStreamExt;
17use postgres_protocol::message::frontend;
18use tokio::io::{AsyncRead, AsyncWrite};
19
20pub struct Transaction<'a> {
25 client: &'a mut Client,
26 savepoint: Option<Savepoint>,
27 done: bool,
28}
29
/// Identity of the SQL savepoint that backs a nested `Transaction`.
struct Savepoint {
    /// Savepoint identifier. It is interpolated verbatim (unquoted) into the
    /// `SAVEPOINT`, `RELEASE` and `ROLLBACK TO` statements.
    name: String,
    /// Nesting level: 1 for a savepoint opened directly inside a top-level
    /// transaction, one more than the parent's level for deeper savepoints.
    depth: u32,
}
35
36impl<'a> Drop for Transaction<'a> {
37 fn drop(&mut self) {
38 if self.done {
39 return;
40 }
41
42 let query = if let Some(sp) = self.savepoint.as_ref() {
43 format!("ROLLBACK TO {}", sp.name)
44 } else {
45 "ROLLBACK".to_string()
46 };
47 let buf = self.client.inner().with_buf(|buf| {
48 frontend::query(&query, buf).unwrap();
49 buf.split().freeze()
50 });
51 let _ = self
52 .client
53 .inner()
54 .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
55 }
56}
57
58impl<'a> Transaction<'a> {
59 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
60 Transaction {
61 client,
62 savepoint: None,
63 done: false,
64 }
65 }
66
67 pub async fn commit(mut self) -> Result<(), Error> {
69 self.done = true;
70 let query = if let Some(sp) = self.savepoint.as_ref() {
71 format!("RELEASE {}", sp.name)
72 } else {
73 "COMMIT".to_string()
74 };
75 self.client.batch_execute(&query).await
76 }
77
78 pub async fn rollback(mut self) -> Result<(), Error> {
82 self.done = true;
83 let query = if let Some(sp) = self.savepoint.as_ref() {
84 format!("ROLLBACK TO {}", sp.name)
85 } else {
86 "ROLLBACK".to_string()
87 };
88 self.client.batch_execute(&query).await
89 }
90
91 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
93 self.client.prepare(query).await
94 }
95
96 pub async fn prepare_typed(
98 &self,
99 query: &str,
100 parameter_types: &[Type],
101 ) -> Result<Statement, Error> {
102 self.client.prepare_typed(query, parameter_types).await
103 }
104
105 pub async fn query<T>(
107 &self,
108 statement: &T,
109 params: &[&(dyn ToSql + Sync)],
110 ) -> Result<Vec<Row>, Error>
111 where
112 T: ?Sized + ToStatement,
113 {
114 self.client.query(statement, params).await
115 }
116
117 pub async fn query_one<T>(
119 &self,
120 statement: &T,
121 params: &[&(dyn ToSql + Sync)],
122 ) -> Result<Row, Error>
123 where
124 T: ?Sized + ToStatement,
125 {
126 self.client.query_one(statement, params).await
127 }
128
129 pub async fn query_opt<T>(
131 &self,
132 statement: &T,
133 params: &[&(dyn ToSql + Sync)],
134 ) -> Result<Option<Row>, Error>
135 where
136 T: ?Sized + ToStatement,
137 {
138 self.client.query_opt(statement, params).await
139 }
140
141 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
143 where
144 T: ?Sized + ToStatement,
145 P: BorrowToSql,
146 I: IntoIterator<Item = P>,
147 I::IntoIter: ExactSizeIterator,
148 {
149 self.client.query_raw(statement, params).await
150 }
151
152 pub async fn execute<T>(
154 &self,
155 statement: &T,
156 params: &[&(dyn ToSql + Sync)],
157 ) -> Result<u64, Error>
158 where
159 T: ?Sized + ToStatement,
160 {
161 self.client.execute(statement, params).await
162 }
163
164 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
166 where
167 T: ?Sized + ToStatement,
168 P: BorrowToSql,
169 I: IntoIterator<Item = P>,
170 I::IntoIter: ExactSizeIterator,
171 {
172 self.client.execute_raw(statement, params).await
173 }
174
175 pub async fn bind<T>(
184 &self,
185 statement: &T,
186 params: &[&(dyn ToSql + Sync)],
187 ) -> Result<Portal, Error>
188 where
189 T: ?Sized + ToStatement,
190 {
191 self.bind_raw(statement, slice_iter(params)).await
192 }
193
194 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
198 where
199 T: ?Sized + ToStatement,
200 P: BorrowToSql,
201 I: IntoIterator<Item = P>,
202 I::IntoIter: ExactSizeIterator,
203 {
204 let statement = statement.__convert().into_statement(self.client).await?;
205 bind::bind(self.client.inner(), statement, params).await
206 }
207
208 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
213 self.query_portal_raw(portal, max_rows)
214 .await?
215 .try_collect()
216 .await
217 }
218
219 pub async fn query_portal_raw(
223 &self,
224 portal: &Portal,
225 max_rows: i32,
226 ) -> Result<RowStream, Error> {
227 query::query_portal(self.client.inner(), portal, max_rows).await
228 }
229
230 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
232 where
233 T: ?Sized + ToStatement,
234 U: Buf + 'static + Send,
235 {
236 self.client.copy_in(statement).await
237 }
238
239 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
241 where
242 T: ?Sized + ToStatement,
243 {
244 self.client.copy_out(statement).await
245 }
246
247 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
249 self.client.simple_query(query).await
250 }
251
252 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
254 self.client.batch_execute(query).await
255 }
256
257 pub fn cancel_token(&self) -> CancelToken {
259 self.client.cancel_token()
260 }
261
262 #[cfg(feature = "runtime")]
264 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
265 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
266 where
267 T: MakeTlsConnect<Socket>,
268 {
269 #[allow(deprecated)]
270 self.client.cancel_query(tls).await
271 }
272
273 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
275 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
276 where
277 S: AsyncRead + AsyncWrite + Unpin,
278 T: TlsConnect<S>,
279 {
280 #[allow(deprecated)]
281 self.client.cancel_query_raw(stream, tls).await
282 }
283
284 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
286 self._savepoint(None).await
287 }
288
289 pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
291 where
292 I: Into<String>,
293 {
294 self._savepoint(Some(name.into())).await
295 }
296
297 async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
298 let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
299 let name = name.unwrap_or_else(|| format!("sp_{}", depth));
300 let query = format!("SAVEPOINT {}", name);
301 self.batch_execute(&query).await?;
302
303 Ok(Transaction {
304 client: self.client,
305 savepoint: Some(Savepoint { name, depth }),
306 done: false,
307 })
308 }
309
310 pub fn client(&self) -> &Client {
312 self.client
313 }
314}