1mod builder;
2mod portal;
3
4use core::ops::{Deref, DerefMut};
5
6use std::borrow::Cow;
7
8use super::{
9 client::{Client, ClientBorrowMut},
10 driver::codec::AsParams,
11 error::Error,
12 execute::Execute,
13 pool::PoolConnection,
14 statement::Statement,
15 types::ToSql,
16};
17
18pub use builder::{IsolationLevel, TransactionBuilder};
19pub use portal::Portal;
20
/// Bookkeeping for one level of a (possibly nested) transaction.
///
/// `depth == 0` is treated as the top-level transaction: `commit`/`rollback`
/// send `COMMIT`/`ROLLBACK`. Any other depth is treated as a savepoint and the
/// same methods send `RELEASE <name>`/`ROLLBACK TO <name>` instead.
struct SavePoint {
    // Caller-supplied savepoint name; when `None` the name `sp_{depth}` is used.
    name: Option<String>,
    // Nesting level; `Default` starts at 0 and `nest_save_point` yields `self.depth + 1`.
    depth: u32,
    // Whether a rollback still has to be issued when the owner is dropped.
    state: State,
}
26
/// Completion state of a `SavePoint`.
enum State {
    /// Neither `commit` nor `rollback` has been called yet; `SavePoint::on_drop`
    /// triggers a rollback while in this state.
    WantRollback,
    /// `commit` or `rollback` has been called; `SavePoint::on_drop` does nothing.
    Finish,
}
31
32impl Default for SavePoint {
33 fn default() -> Self {
34 Self {
35 name: None,
36 depth: 0,
37 state: State::WantRollback,
38 }
39 }
40}
41
42impl SavePoint {
43 fn rollback(&mut self, cli: impl ClientBorrowMut) -> impl Future<Output = Result<(), Error>> + Send {
45 self.state = State::Finish;
46
47 let fut = match self.depth {
48 0 => Cow::Borrowed("ROLLBACK"),
49 depth => match self.name {
50 None => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
51 Some(ref name) => Cow::Owned(format!("ROLLBACK TO {name}")),
52 },
53 }
54 .execute(cli.borrow_cli_ref());
55
56 async { fut.await.map(|_| ()) }
57 }
58
59 async fn commit(&mut self, cli: impl ClientBorrowMut) -> Result<(), Error> {
60 self.state = State::Finish;
61
62 match self.depth {
63 0 => Cow::Borrowed("COMMIT"),
64 depth => match self.name {
65 None => Cow::Owned(format!("RELEASE sp_{depth}")),
66 Some(ref name) => Cow::Owned(format!("RELEASE {name}")),
67 },
68 }
69 .execute(cli.borrow_cli_ref())
70 .await
71 .map(|_| ())
72 }
73
74 async fn nest_save_point(&self, cli: impl ClientBorrowMut, name: Option<String>) -> Result<Self, Error> {
75 let depth = self.depth + 1;
76
77 match self.depth {
78 0 => match name {
79 Some(ref name) => Cow::Owned(format!("SAVEPOINT {name}")),
80 None => Cow::Borrowed("SAVEPOINT sp_1"),
81 },
82 depth => match name {
83 Some(ref name) => Cow::Owned(format!("SAVEPOINT {name}")),
84 None => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
85 },
86 }
87 .execute(cli.borrow_cli_ref())
88 .await
89 .map(|_| SavePoint {
90 name,
91 depth,
92 state: State::WantRollback,
93 })
94 }
95
96 fn on_drop(&mut self, cli: impl ClientBorrowMut) {
97 if matches!(self.state, State::WantRollback) {
98 drop(self.rollback(cli));
99 }
100 }
101}
102
/// Handle for an open transaction or nested savepoint on a client `C`.
///
/// Dropping the handle without calling `commit` or `rollback` triggers a
/// rollback through the `Drop` impl.
pub struct Transaction<'a, C>
where
    C: ClientBorrowMut,
{
    // The client the transaction runs on, either owned or mutably borrowed.
    client: _Client<'a, C>,
    // Name/depth/completion state of this transaction level.
    save_point: SavePoint,
}
110
/// A client that is either held by value or through an exclusive borrow.
///
/// Both variants dereference to the underlying `C`.
enum _Client<'a, C> {
    Owned(C),
    Borrowed(&'a mut C),
}

impl<C> _Client<'_, C> {
    /// Produces a `Borrowed` handle to the same underlying client, valid for
    /// as long as `self` stays mutably borrowed.
    #[inline]
    fn reborrow(&mut self) -> _Client<'_, C> {
        let inner: &mut C = &mut **self;
        _Client::Borrowed(inner)
    }
}

impl<C> Deref for _Client<'_, C> {
    type Target = C;

    /// Shared access to the client regardless of how it is held.
    fn deref(&self) -> &Self::Target {
        match *self {
            Self::Owned(ref owned) => owned,
            Self::Borrowed(ref borrowed) => &**borrowed,
        }
    }
}

impl<C> DerefMut for _Client<'_, C> {
    /// Exclusive access to the client regardless of how it is held.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match *self {
            Self::Owned(ref mut owned) => owned,
            Self::Borrowed(ref mut borrowed) => &mut **borrowed,
        }
    }
}
142
143impl<C> Drop for Transaction<'_, C>
144where
145 C: ClientBorrowMut,
146{
147 fn drop(&mut self) {
148 self.save_point.on_drop(self.client.deref_mut());
149 }
150}
151
152impl<C> Transaction<'_, C>
153where
154 C: ClientBorrowMut,
155{
156 pub async fn bind<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
158 where
159 I: AsParams,
160 {
161 Portal::new(&*self.client, statement, params).await
162 }
163
164 pub async fn bind_dyn<'p>(
169 &'p self,
170 statement: &'p Statement,
171 params: &[&(dyn ToSql + Sync)],
172 ) -> Result<Portal<'p, C>, Error> {
173 self.bind(statement, params.iter().cloned()).await
174 }
175
176 pub async fn transaction(&mut self) -> Result<Transaction<'_, C>, Error> {
180 self._save_point(None).await
181 }
182
183 pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<'_, C>, Error>
187 where
188 I: Into<String>,
189 {
190 self._save_point(Some(name.into())).await
191 }
192
193 pub async fn commit(mut self) -> Result<(), Error> {
195 self.save_point.commit(self.client.deref_mut()).await
196 }
197
198 pub async fn rollback(mut self) -> Result<(), Error> {
202 self.save_point.rollback(self.client.deref_mut()).await
203 }
204
205 fn new(client: &mut C) -> Transaction<'_, C> {
206 Transaction {
207 client: _Client::Borrowed(client),
208 save_point: SavePoint::default(),
209 }
210 }
211
212 fn new_owned<'a>(client: C) -> Transaction<'a, C>
213 where
214 C: 'a,
215 {
216 Transaction {
217 client: _Client::Owned(client),
218 save_point: SavePoint::default(),
219 }
220 }
221
222 async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<'_, C>, Error> {
223 let save_point = self.save_point.nest_save_point(self.client.deref_mut(), name).await?;
224 Ok(Transaction {
225 client: self.client.reborrow(),
226 save_point,
227 })
228 }
229}
230
231impl<'c, C, Q, EO, QO> Execute<&'c Transaction<'_, C>> for Q
232where
233 C: ClientBorrowMut,
234 Q: Execute<&'c Client, ExecuteOutput = EO, QueryOutput = QO>,
235{
236 type ExecuteOutput = EO;
237 type QueryOutput = QO;
238
239 #[inline]
240 fn execute(self, cli: &'c Transaction<'_, C>) -> Self::ExecuteOutput {
241 Q::execute(self, cli.client.borrow_cli_ref())
242 }
243
244 #[inline]
245 fn query(self, cli: &'c Transaction<C>) -> Self::QueryOutput {
246 Q::query(self, cli.client.borrow_cli_ref())
247 }
248}
249
250impl<'c, 'p, Q, EO, QO> Execute<&'c mut Transaction<'_, PoolConnection<'p>>> for Q
252where
253 Q: Execute<&'c mut PoolConnection<'p>, ExecuteOutput = EO, QueryOutput = QO>,
254{
255 type ExecuteOutput = EO;
256 type QueryOutput = QO;
257
258 #[inline]
259 fn execute(self, cli: &'c mut Transaction<'_, PoolConnection<'p>>) -> Self::ExecuteOutput {
260 Q::execute(self, cli.client.deref_mut())
261 }
262
263 #[inline]
264 fn query(self, cli: &'c mut Transaction<PoolConnection<'p>>) -> Self::QueryOutput {
265 Q::query(self, cli.client.deref_mut())
266 }
267}