1use crate::connection::ConnectionRef;
2use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
3use yb_tokio_postgres::types::{BorrowToSql, ToSql, Type};
4use yb_tokio_postgres::{Error, Row, SimpleQueryMessage};
5
6pub struct Transaction<'a> {
11 connection: ConnectionRef<'a>,
12 transaction: Option<yb_tokio_postgres::Transaction<'a>>,
13}
14
15impl<'a> Drop for Transaction<'a> {
16 fn drop(&mut self) {
17 if let Some(transaction) = self.transaction.take() {
18 let _ = self.connection.block_on(transaction.rollback());
19 }
20 }
21}
22
23impl<'a> Transaction<'a> {
24 pub(crate) fn new(
25 connection: ConnectionRef<'a>,
26 transaction: yb_tokio_postgres::Transaction<'a>,
27 ) -> Transaction<'a> {
28 Transaction {
29 connection,
30 transaction: Some(transaction),
31 }
32 }
33
34 pub fn commit(mut self) -> Result<(), Error> {
36 self.connection
37 .block_on(self.transaction.take().unwrap().commit())
38 }
39
40 pub fn rollback(mut self) -> Result<(), Error> {
44 self.connection
45 .block_on(self.transaction.take().unwrap().rollback())
46 }
47
48 pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
50 self.connection
51 .block_on(self.transaction.as_ref().unwrap().prepare(query))
52 }
53
54 pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
56 self.connection.block_on(
57 self.transaction
58 .as_ref()
59 .unwrap()
60 .prepare_typed(query, types),
61 )
62 }
63
64 pub fn execute<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
66 where
67 T: ?Sized + ToStatement,
68 {
69 self.connection
70 .block_on(self.transaction.as_ref().unwrap().execute(query, params))
71 }
72
73 pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
75 where
76 T: ?Sized + ToStatement,
77 {
78 self.connection
79 .block_on(self.transaction.as_ref().unwrap().query(query, params))
80 }
81
82 pub fn query_one<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error>
84 where
85 T: ?Sized + ToStatement,
86 {
87 self.connection
88 .block_on(self.transaction.as_ref().unwrap().query_one(query, params))
89 }
90
91 pub fn query_opt<T>(
93 &mut self,
94 query: &T,
95 params: &[&(dyn ToSql + Sync)],
96 ) -> Result<Option<Row>, Error>
97 where
98 T: ?Sized + ToStatement,
99 {
100 self.connection
101 .block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
102 }
103
104 pub fn query_raw<T, P, I>(&mut self, query: &T, params: I) -> Result<RowIter<'_>, Error>
106 where
107 T: ?Sized + ToStatement,
108 P: BorrowToSql,
109 I: IntoIterator<Item = P>,
110 I::IntoIter: ExactSizeIterator,
111 {
112 let stream = self
113 .connection
114 .block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
115 Ok(RowIter::new(self.connection.as_ref(), stream))
116 }
117
118 pub fn bind<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Portal, Error>
129 where
130 T: ?Sized + ToStatement,
131 {
132 self.connection
133 .block_on(self.transaction.as_ref().unwrap().bind(query, params))
134 }
135
136 pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
141 self.connection.block_on(
142 self.transaction
143 .as_ref()
144 .unwrap()
145 .query_portal(portal, max_rows),
146 )
147 }
148
149 pub fn query_portal_raw(
151 &mut self,
152 portal: &Portal,
153 max_rows: i32,
154 ) -> Result<RowIter<'_>, Error> {
155 let stream = self.connection.block_on(
156 self.transaction
157 .as_ref()
158 .unwrap()
159 .query_portal_raw(portal, max_rows),
160 )?;
161 Ok(RowIter::new(self.connection.as_ref(), stream))
162 }
163
164 pub fn copy_in<T>(&mut self, query: &T) -> Result<CopyInWriter<'_>, Error>
166 where
167 T: ?Sized + ToStatement,
168 {
169 let sink = self
170 .connection
171 .block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
172 Ok(CopyInWriter::new(self.connection.as_ref(), sink))
173 }
174
175 pub fn copy_out<T>(&mut self, query: &T) -> Result<CopyOutReader<'_>, Error>
177 where
178 T: ?Sized + ToStatement,
179 {
180 let stream = self
181 .connection
182 .block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
183 Ok(CopyOutReader::new(self.connection.as_ref(), stream))
184 }
185
186 pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
188 self.connection
189 .block_on(self.transaction.as_ref().unwrap().simple_query(query))
190 }
191
192 pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
194 self.connection
195 .block_on(self.transaction.as_ref().unwrap().batch_execute(query))
196 }
197
198 pub fn cancel_token(&self) -> CancelToken {
200 CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
201 }
202
203 pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
205 let transaction = self
206 .connection
207 .block_on(self.transaction.as_mut().unwrap().transaction())?;
208 Ok(Transaction::new(self.connection.as_ref(), transaction))
209 }
210
211 pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
213 where
214 I: Into<String>,
215 {
216 let transaction = self
217 .connection
218 .block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
219 Ok(Transaction::new(self.connection.as_ref(), transaction))
220 }
221}