1mod builder;
2mod portal;
3
4use std::borrow::Cow;
5
6use super::{
7 client::ClientBorrowMut,
8 driver::codec::{encode::Encode, AsParams, Response},
9 error::Error,
10 execute::Execute,
11 prepare::Prepare,
12 query::Query,
13 statement::Statement,
14 types::{Oid, ToSql, Type},
15 BoxedFuture,
16};
17
18pub use builder::TransactionBuilder;
19pub use portal::Portal;
20
21pub struct Transaction<'a, C>
22where
23 C: Prepare + ClientBorrowMut,
24{
25 client: &'a mut C,
26 save_point: SavePoint,
27 state: State,
28}
29
/// Savepoint bookkeeping for a (possibly nested) [`Transaction`].
///
/// `depth` is the nesting level below the outermost transaction. The first nested
/// transaction or savepoint sits at depth 1, its child at depth 2, and so on.
/// Auto-named savepoints derive their name (`sp_{depth}`) from it.
enum SavePoint {
    /// The outermost transaction; no savepoint is involved.
    None,
    /// A nested level whose savepoint name is generated as `sp_{depth}`.
    Auto { depth: u32 },
    /// A nested level whose savepoint name was supplied by the caller.
    Custom { name: String, depth: u32 },
}

impl SavePoint {
    /// Nesting depth of this level; `0` for the outermost transaction.
    fn depth(&self) -> u32 {
        match *self {
            Self::None => 0,
            Self::Auto { depth } | Self::Custom { depth, .. } => depth,
        }
    }

    /// Returns the savepoint for a level nested exactly one deeper than `self`.
    ///
    /// `Some(name)` yields a caller-named savepoint and `None` an auto-named one.
    /// Either kind sits one level deeper than its parent, whether the parent was
    /// auto-named or caller-named. Previously a named child of a nested level reused
    /// the parent's depth, so later auto-named levels were mis-numbered.
    fn nest_save_point(&self, name: Option<String>) -> Self {
        let depth = self.depth() + 1;
        match name {
            Some(name) => Self::Custom { name, depth },
            None => Self::Auto { depth },
        }
    }

    /// SQL that establishes this savepoint.
    ///
    /// `nest_save_point` never yields `None`, so the bare `"SAVEPOINT"` arm is not
    /// expected to be reached through it.
    // NOTE(review): custom names are spliced into the SQL verbatim (unquoted,
    // unescaped); callers must pass a trusted, valid SQL identifier.
    fn save_point_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("SAVEPOINT"),
            Self::Auto { depth } => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("SAVEPOINT {name}")),
        }
    }

    /// SQL that finishes this level successfully.
    ///
    /// The outermost transaction uses `COMMIT`; a nested level uses `RELEASE`.
    fn commit_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("COMMIT"),
            Self::Auto { depth } => Cow::Owned(format!("RELEASE sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("RELEASE {name}")),
        }
    }

    /// SQL that abandons this level.
    ///
    /// The outermost transaction uses `ROLLBACK`; a nested level uses `ROLLBACK TO`.
    fn rollback_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("ROLLBACK"),
            Self::Auto { depth } => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("ROLLBACK TO {name}")),
        }
    }
}
74
/// Whether a [`Transaction`] still needs cleanup when it is dropped.
enum State {
    /// Neither `commit` nor `rollback` has run yet, so `Drop` must roll back.
    WantRollback,
    /// The transaction was explicitly finished, so `Drop` has nothing to do.
    Finish,
}
79
80impl<C> Drop for Transaction<'_, C>
81where
82 C: Prepare + ClientBorrowMut,
83{
84 fn drop(&mut self) {
85 match self.state {
86 State::WantRollback => self.do_rollback(),
87 State::Finish => {}
88 }
89 }
90}
91
92impl<C> Transaction<'_, C>
93where
94 C: Prepare + ClientBorrowMut,
95{
96 pub fn builder() -> TransactionBuilder {
97 TransactionBuilder::new()
98 }
99
100 pub async fn bind<'p>(
105 &'p self,
106 statement: &'p Statement,
107 params: &[&(dyn ToSql + Sync)],
108 ) -> Result<Portal<'p, C>, Error> {
109 self.bind_raw(statement, params.iter().cloned()).await
110 }
111
112 pub async fn bind_raw<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
114 where
115 I: AsParams,
116 {
117 Portal::new(self.client, statement, params).await
118 }
119
120 pub async fn transaction(&mut self) -> Result<Transaction<C>, Error> {
124 self._save_point(None).await
125 }
126
127 pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<C>, Error>
131 where
132 I: Into<String>,
133 {
134 self._save_point(Some(name.into())).await
135 }
136
137 pub async fn commit(mut self) -> Result<(), Error> {
139 self.state = State::Finish;
140 self.save_point.commit_query().execute(&self).await?;
141 Ok(())
142 }
143
144 pub async fn rollback(mut self) -> Result<(), Error> {
148 self.state = State::Finish;
149 self.save_point.rollback_query().execute(&self).await?;
150 Ok(())
151 }
152
153 fn new(client: &mut C) -> Transaction<C> {
154 Transaction {
155 client,
156 save_point: SavePoint::None,
157 state: State::WantRollback,
158 }
159 }
160
161 async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<C>, Error> {
162 let save_point = self.save_point.nest_save_point(name);
163 save_point.save_point_query().execute(self).await?;
164
165 Ok(Transaction {
166 client: self.client,
167 save_point,
168 state: State::WantRollback,
169 })
170 }
171
172 fn do_rollback(&mut self) {
173 drop(self.save_point.rollback_query().execute(self));
174 }
175}
176
177impl<C> Prepare for Transaction<'_, C>
178where
179 C: Prepare + ClientBorrowMut,
180{
181 #[inline]
182 fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
183 self.client._get_type(oid)
184 }
185
186 #[inline]
187 fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
188 self.client._get_type_blocking(oid)
189 }
190}
191
192impl<C> Query for Transaction<'_, C>
193where
194 C: Prepare + ClientBorrowMut,
195{
196 #[inline]
197 fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
198 where
199 S: Encode,
200 {
201 self.client._send_encode_query(stmt)
202 }
203}