1#[cfg(feature = "runtime")]
2use crate::Socket;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9use crate::{
10 CancelToken, Client, CopyInSink, Error, Portal, Row, SimpleQueryMessage, Statement,
11 ToStatement, bind, query, slice_iter,
12};
13use bytes::Buf;
14use futures_util::TryStreamExt;
15use tokio::io::{AsyncRead, AsyncWrite};
16
17pub struct Transaction<'a> {
22 client: &'a mut Client,
23 savepoint: Option<Savepoint>,
24 done: bool,
25}
26
/// Bookkeeping for a nested transaction that is backed by a savepoint.
struct Savepoint {
    /// Savepoint identifier, interpolated verbatim into the `SAVEPOINT`,
    /// `RELEASE` and `ROLLBACK TO` statements.
    name: String,
    /// Nesting level: one more than the parent's depth, with a top-level
    /// transaction counting as zero. Used to derive the default `sp_<depth>` name.
    depth: u32,
}
32
33impl Drop for Transaction<'_> {
34 fn drop(&mut self) {
35 if self.done {
36 return;
37 }
38
39 let name = self.savepoint.as_ref().map(|sp| sp.name.as_str());
40 self.client.__private_api_rollback(name);
41 }
42}
43
44impl<'a> Transaction<'a> {
45 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
46 Transaction {
47 client,
48 savepoint: None,
49 done: false,
50 }
51 }
52
53 pub async fn commit(mut self) -> Result<(), Error> {
55 self.done = true;
56 let query = if let Some(sp) = self.savepoint.as_ref() {
57 format!("RELEASE {}", sp.name)
58 } else {
59 "COMMIT".to_string()
60 };
61 self.client.batch_execute(&query).await
62 }
63
64 pub async fn rollback(mut self) -> Result<(), Error> {
68 self.done = true;
69 let query = if let Some(sp) = self.savepoint.as_ref() {
70 format!("ROLLBACK TO {}", sp.name)
71 } else {
72 "ROLLBACK".to_string()
73 };
74 self.client.batch_execute(&query).await
75 }
76
77 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
79 self.client.prepare(query).await
80 }
81
82 pub async fn prepare_typed(
84 &self,
85 query: &str,
86 parameter_types: &[Type],
87 ) -> Result<Statement, Error> {
88 self.client.prepare_typed(query, parameter_types).await
89 }
90
91 pub async fn query<T>(
93 &self,
94 statement: &T,
95 params: &[&(dyn ToSql + Sync)],
96 ) -> Result<Vec<Row>, Error>
97 where
98 T: ?Sized + ToStatement,
99 {
100 self.client.query(statement, params).await
101 }
102
103 pub async fn query_one<T>(
105 &self,
106 statement: &T,
107 params: &[&(dyn ToSql + Sync)],
108 ) -> Result<Row, Error>
109 where
110 T: ?Sized + ToStatement,
111 {
112 self.client.query_one(statement, params).await
113 }
114
115 pub async fn query_opt<T>(
117 &self,
118 statement: &T,
119 params: &[&(dyn ToSql + Sync)],
120 ) -> Result<Option<Row>, Error>
121 where
122 T: ?Sized + ToStatement,
123 {
124 self.client.query_opt(statement, params).await
125 }
126
127 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
129 where
130 T: ?Sized + ToStatement,
131 P: BorrowToSql,
132 I: IntoIterator<Item = P>,
133 I::IntoIter: ExactSizeIterator,
134 {
135 self.client.query_raw(statement, params).await
136 }
137
138 pub async fn query_typed(
140 &self,
141 statement: &str,
142 params: &[(&(dyn ToSql + Sync), Type)],
143 ) -> Result<Vec<Row>, Error> {
144 self.client.query_typed(statement, params).await
145 }
146
147 pub async fn query_typed_one(
149 &self,
150 statement: &str,
151 params: &[(&(dyn ToSql + Sync), Type)],
152 ) -> Result<Row, Error> {
153 self.client.query_typed_one(statement, params).await
154 }
155
156 pub async fn query_typed_opt(
158 &self,
159 statement: &str,
160 params: &[(&(dyn ToSql + Sync), Type)],
161 ) -> Result<Option<Row>, Error> {
162 self.client.query_typed_opt(statement, params).await
163 }
164
165 pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
167 where
168 P: BorrowToSql,
169 I: IntoIterator<Item = (P, Type)>,
170 {
171 self.client.query_typed_raw(query, params).await
172 }
173
174 pub async fn execute<T>(
176 &self,
177 statement: &T,
178 params: &[&(dyn ToSql + Sync)],
179 ) -> Result<u64, Error>
180 where
181 T: ?Sized + ToStatement,
182 {
183 self.client.execute(statement, params).await
184 }
185
186 pub async fn execute_typed(
188 &self,
189 statement: &str,
190 params: &[(&(dyn ToSql + Sync), Type)],
191 ) -> Result<u64, Error> {
192 self.client.execute_typed(statement, params).await
193 }
194
195 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
197 where
198 T: ?Sized + ToStatement,
199 P: BorrowToSql,
200 I: IntoIterator<Item = P>,
201 I::IntoIter: ExactSizeIterator,
202 {
203 self.client.execute_raw(statement, params).await
204 }
205
206 pub async fn bind<T>(
215 &self,
216 statement: &T,
217 params: &[&(dyn ToSql + Sync)],
218 ) -> Result<Portal, Error>
219 where
220 T: ?Sized + ToStatement,
221 {
222 self.bind_raw(statement, slice_iter(params)).await
223 }
224
225 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
229 where
230 T: ?Sized + ToStatement,
231 P: BorrowToSql,
232 I: IntoIterator<Item = P>,
233 I::IntoIter: ExactSizeIterator,
234 {
235 let statement = statement
236 .__convert()
237 .into_statement(self.client.inner())
238 .await?;
239 bind::bind(self.client.inner(), statement, params).await
240 }
241
242 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
247 self.query_portal_raw(portal, max_rows)
248 .await?
249 .try_collect()
250 .await
251 }
252
253 pub async fn query_portal_raw(
257 &self,
258 portal: &Portal,
259 max_rows: i32,
260 ) -> Result<RowStream, Error> {
261 query::query_portal(self.client.inner(), portal, max_rows).await
262 }
263
264 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
266 where
267 T: ?Sized + ToStatement,
268 U: Buf + 'static + Send,
269 {
270 self.client.copy_in(statement).await
271 }
272
273 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
275 where
276 T: ?Sized + ToStatement,
277 {
278 self.client.copy_out(statement).await
279 }
280
281 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
283 self.client.simple_query(query).await
284 }
285
286 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
288 self.client.batch_execute(query).await
289 }
290
291 pub fn cancel_token(&self) -> CancelToken {
293 self.client.cancel_token()
294 }
295
296 #[cfg(feature = "runtime")]
298 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
299 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
300 where
301 T: MakeTlsConnect<Socket>,
302 {
303 #[allow(deprecated)]
304 self.client.cancel_query(tls).await
305 }
306
307 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
309 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
310 where
311 S: AsyncRead + AsyncWrite + Unpin,
312 T: TlsConnect<S>,
313 {
314 #[allow(deprecated)]
315 self.client.cancel_query_raw(stream, tls).await
316 }
317
318 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
320 self._savepoint(None).await
321 }
322
323 pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
325 where
326 I: Into<String>,
327 {
328 self._savepoint(Some(name.into())).await
329 }
330
331 async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
332 let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
333 let name = name.unwrap_or_else(|| format!("sp_{depth}"));
334 let query = format!("SAVEPOINT {name}");
335 self.batch_execute(&query).await?;
336
337 Ok(Transaction {
338 client: self.client,
339 savepoint: Some(Savepoint { name, depth }),
340 done: false,
341 })
342 }
343
344 pub fn client(&self) -> &Client {
346 self.client
347 }
348}