1use crate::copy_out::CopyOutStream;
2use crate::query::RowStream;
3#[cfg(feature = "runtime")]
4use crate::tls::MakeTlsConnect;
5use crate::tls::TlsConnect;
6use crate::types::{BorrowToSql, ToSql, Type};
7#[cfg(feature = "runtime")]
8use crate::Socket;
9use crate::{
10 bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
11 SimpleQueryMessage, Statement, ToStatement,
12};
13use bytes::Buf;
14use futures_util::TryStreamExt;
15use tokio::io::{AsyncRead, AsyncWrite};
16
/// A handle to an open database transaction, or to a savepoint nested inside one.
///
/// Dropping the handle without calling `commit` or `rollback` asks the client
/// to roll the work back (see the `Drop` implementation). Nested handles are
/// created with `transaction` / `savepoint` and are backed by SQL savepoints.
pub struct Transaction<'a> {
    /// Connection the transaction runs on, exclusively borrowed for `'a`.
    client: &'a mut Client,
    /// `None` for a top-level transaction; `Some` when this handle represents
    /// a savepoint inside an enclosing transaction.
    savepoint: Option<Savepoint>,
    /// Set once `commit` or `rollback` has been called so that `Drop` does
    /// not issue another rollback.
    done: bool,
}
26
/// Bookkeeping for a savepoint that backs a nested `Transaction`.
struct Savepoint {
    /// Savepoint identifier, interpolated verbatim into the
    /// `SAVEPOINT` / `RELEASE` / `ROLLBACK TO` statements.
    name: String,
    /// Nesting level: the enclosing handle's depth plus one, where a
    /// top-level transaction counts as depth 0.
    depth: u32,
}
32
33impl Drop for Transaction<'_> {
34 fn drop(&mut self) {
35 if self.done {
36 return;
37 }
38
39 let name = self.savepoint.as_ref().map(|sp| sp.name.as_str());
40 self.client.__private_api_rollback(name);
41 }
42}
43
44impl<'a> Transaction<'a> {
45 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
46 Transaction {
47 client,
48 savepoint: None,
49 done: false,
50 }
51 }
52
53 pub async fn commit(mut self) -> Result<(), Error> {
55 self.done = true;
56 let query = if let Some(sp) = self.savepoint.as_ref() {
57 format!("RELEASE {}", sp.name)
58 } else {
59 "COMMIT".to_string()
60 };
61 self.client.batch_execute(&query).await
62 }
63
64 pub async fn rollback(mut self) -> Result<(), Error> {
68 self.done = true;
69 let query = if let Some(sp) = self.savepoint.as_ref() {
70 format!("ROLLBACK TO {}", sp.name)
71 } else {
72 "ROLLBACK".to_string()
73 };
74 self.client.batch_execute(&query).await
75 }
76
77 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
79 self.client.prepare(query).await
80 }
81
82 pub async fn prepare_typed(
84 &self,
85 query: &str,
86 parameter_types: &[Type],
87 ) -> Result<Statement, Error> {
88 self.client.prepare_typed(query, parameter_types).await
89 }
90
91 pub async fn query<T>(
93 &self,
94 statement: &T,
95 params: &[&(dyn ToSql + Sync)],
96 ) -> Result<Vec<Row>, Error>
97 where
98 T: ?Sized + ToStatement,
99 {
100 self.client.query(statement, params).await
101 }
102
103 pub async fn query_one<T>(
105 &self,
106 statement: &T,
107 params: &[&(dyn ToSql + Sync)],
108 ) -> Result<Row, Error>
109 where
110 T: ?Sized + ToStatement,
111 {
112 self.client.query_one(statement, params).await
113 }
114
115 pub async fn query_opt<T>(
117 &self,
118 statement: &T,
119 params: &[&(dyn ToSql + Sync)],
120 ) -> Result<Option<Row>, Error>
121 where
122 T: ?Sized + ToStatement,
123 {
124 self.client.query_opt(statement, params).await
125 }
126
127 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
129 where
130 T: ?Sized + ToStatement,
131 P: BorrowToSql,
132 I: IntoIterator<Item = P>,
133 I::IntoIter: ExactSizeIterator,
134 {
135 self.client.query_raw(statement, params).await
136 }
137
138 pub async fn query_typed(
140 &self,
141 statement: &str,
142 params: &[(&(dyn ToSql + Sync), Type)],
143 ) -> Result<Vec<Row>, Error> {
144 self.client.query_typed(statement, params).await
145 }
146
147 pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
149 where
150 P: BorrowToSql,
151 I: IntoIterator<Item = (P, Type)>,
152 {
153 self.client.query_typed_raw(query, params).await
154 }
155
156 pub async fn execute<T>(
158 &self,
159 statement: &T,
160 params: &[&(dyn ToSql + Sync)],
161 ) -> Result<u64, Error>
162 where
163 T: ?Sized + ToStatement,
164 {
165 self.client.execute(statement, params).await
166 }
167
168 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
170 where
171 T: ?Sized + ToStatement,
172 P: BorrowToSql,
173 I: IntoIterator<Item = P>,
174 I::IntoIter: ExactSizeIterator,
175 {
176 self.client.execute_raw(statement, params).await
177 }
178
179 pub async fn bind<T>(
188 &self,
189 statement: &T,
190 params: &[&(dyn ToSql + Sync)],
191 ) -> Result<Portal, Error>
192 where
193 T: ?Sized + ToStatement,
194 {
195 self.bind_raw(statement, slice_iter(params)).await
196 }
197
198 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
202 where
203 T: ?Sized + ToStatement,
204 P: BorrowToSql,
205 I: IntoIterator<Item = P>,
206 I::IntoIter: ExactSizeIterator,
207 {
208 let statement = statement.__convert().into_statement(self.client).await?;
209 bind::bind(self.client.inner(), statement, params).await
210 }
211
212 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
217 self.query_portal_raw(portal, max_rows)
218 .await?
219 .try_collect()
220 .await
221 }
222
223 pub async fn query_portal_raw(
227 &self,
228 portal: &Portal,
229 max_rows: i32,
230 ) -> Result<RowStream, Error> {
231 query::query_portal(self.client.inner(), portal, max_rows).await
232 }
233
234 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
236 where
237 T: ?Sized + ToStatement,
238 U: Buf + 'static + Send,
239 {
240 self.client.copy_in(statement).await
241 }
242
243 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
245 where
246 T: ?Sized + ToStatement,
247 {
248 self.client.copy_out(statement).await
249 }
250
251 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
253 self.client.simple_query(query).await
254 }
255
256 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
258 self.client.batch_execute(query).await
259 }
260
261 pub fn cancel_token(&self) -> CancelToken {
263 self.client.cancel_token()
264 }
265
266 #[cfg(feature = "runtime")]
268 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
269 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
270 where
271 T: MakeTlsConnect<Socket>,
272 {
273 #[allow(deprecated)]
274 self.client.cancel_query(tls).await
275 }
276
277 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
279 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
280 where
281 S: AsyncRead + AsyncWrite + Unpin,
282 T: TlsConnect<S>,
283 {
284 #[allow(deprecated)]
285 self.client.cancel_query_raw(stream, tls).await
286 }
287
288 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
290 self._savepoint(None).await
291 }
292
293 pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
295 where
296 I: Into<String>,
297 {
298 self._savepoint(Some(name.into())).await
299 }
300
301 async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
302 let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
303 let name = name.unwrap_or_else(|| format!("sp_{depth}"));
304 let query = format!("SAVEPOINT {name}");
305 self.batch_execute(&query).await?;
306
307 Ok(Transaction {
308 client: self.client,
309 savepoint: Some(Savepoint { name, depth }),
310 done: false,
311 })
312 }
313
314 pub fn client(&self) -> &Client {
316 self.client
317 }
318}