1mod builder;
2mod portal;
3
4use std::borrow::Cow;
5
6use super::{
7 client::ClientBorrowMut,
8 driver::codec::{AsParams, Response, encode::Encode},
9 error::Error,
10 execute::Execute,
11 prepare::Prepare,
12 query::Query,
13 statement::Statement,
14 types::{Oid, ToSql, Type},
15};
16
17pub use builder::{IsolationLevel, TransactionBuilder};
18pub use portal::Portal;
19
20pub struct Transaction<C>
21where
22 C: Prepare + ClientBorrowMut,
23{
24 client: C,
25 save_point: SavePoint,
26 state: State,
27}
28
/// Position of a transaction in the savepoint hierarchy.
///
/// `depth` is the nesting level below the outermost transaction: the first nested
/// savepoint has depth 1, a savepoint nested inside it has depth 2, and so on,
/// regardless of whether the savepoints are automatically or explicitly named.
enum SavePoint {
    /// Outermost transaction; finished with `COMMIT` / `ROLLBACK`.
    None,
    /// Nested savepoint whose name is generated as `sp_{depth}`.
    Auto { depth: u32 },
    /// Nested savepoint with a caller-supplied name.
    ///
    /// NOTE(review): `name` is interpolated into the SQL text verbatim (no quoting or
    /// escaping), so it must be a trusted, valid SQL identifier.
    Custom { name: String, depth: u32 },
}

impl SavePoint {
    /// Returns the savepoint for a transaction nested one level below `self`.
    ///
    /// With `name == None` the savepoint is named automatically from its depth;
    /// otherwise the given name is used. The result is never `SavePoint::None`.
    fn nest_save_point(&self, name: Option<String>) -> Self {
        // Every nesting step adds exactly one level, for auto and custom names alike,
        // so `depth` always reflects the real nesting level.
        let depth = match *self {
            Self::None => 1,
            Self::Auto { depth } | Self::Custom { depth, .. } => depth + 1,
        };
        match name {
            Some(name) => Self::Custom { name, depth },
            None => Self::Auto { depth },
        }
    }

    /// SQL that creates this savepoint.
    ///
    /// Only meaningful for nested savepoints; the bare `"SAVEPOINT"` returned for
    /// `None` is not issued because `nest_save_point` never produces `None`.
    fn save_point_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("SAVEPOINT"),
            Self::Auto { depth } => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("SAVEPOINT {name}")),
        }
    }

    /// SQL that commits the outermost transaction or releases a nested savepoint.
    fn commit_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("COMMIT"),
            Self::Auto { depth } => Cow::Owned(format!("RELEASE sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("RELEASE {name}")),
        }
    }

    /// SQL that aborts the outermost transaction or rolls back to a nested savepoint.
    fn rollback_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("ROLLBACK"),
            Self::Auto { depth } => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("ROLLBACK TO {name}")),
        }
    }
}
73
74enum State {
75 WantRollback,
76 Finish,
77}
78
79impl<C> Drop for Transaction<C>
80where
81 C: Prepare + ClientBorrowMut,
82{
83 fn drop(&mut self) {
84 match self.state {
85 State::WantRollback => self.do_rollback(),
86 State::Finish => {}
87 }
88 }
89}
90
91impl<C> Transaction<C>
92where
93 C: Prepare + ClientBorrowMut,
94{
95 pub async fn bind<'p>(
100 &'p self,
101 statement: &'p Statement,
102 params: &[&(dyn ToSql + Sync)],
103 ) -> Result<Portal<'p, C>, Error> {
104 self.bind_raw(statement, params.iter().cloned()).await
105 }
106
107 pub async fn bind_raw<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
109 where
110 I: AsParams,
111 {
112 Portal::new(&self.client, statement, params).await
113 }
114
115 pub async fn transaction(&mut self) -> Result<Transaction<&mut C>, Error> {
119 self._save_point(None).await
120 }
121
122 pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<&mut C>, Error>
126 where
127 I: Into<String>,
128 {
129 self._save_point(Some(name.into())).await
130 }
131
132 pub async fn commit(mut self) -> Result<(), Error> {
134 self.state = State::Finish;
135 self.save_point.commit_query().execute(&self).await?;
136 Ok(())
137 }
138
139 pub async fn rollback(mut self) -> Result<(), Error> {
143 self.state = State::Finish;
144 self.save_point.rollback_query().execute(&self).await?;
145 Ok(())
146 }
147
148 fn new(client: C) -> Transaction<C> {
149 Transaction {
150 client,
151 save_point: SavePoint::None,
152 state: State::WantRollback,
153 }
154 }
155
156 async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<&mut C>, Error> {
157 let save_point = self.save_point.nest_save_point(name);
158 save_point.save_point_query().execute(self).await?;
159
160 Ok(Transaction {
161 client: &mut self.client,
162 save_point,
163 state: State::WantRollback,
164 })
165 }
166
167 fn do_rollback(&mut self) {
168 drop(self.save_point.rollback_query().execute(self));
169 }
170}
171
172impl<C> Prepare for Transaction<C>
173where
174 C: Prepare + ClientBorrowMut,
175{
176 #[inline]
177 async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
178 self.client._get_type(oid).await
179 }
180
181 #[inline]
182 fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
183 self.client._get_type_blocking(oid)
184 }
185}
186
187impl<C> Query for Transaction<C>
188where
189 C: Prepare + ClientBorrowMut,
190{
191 #[inline]
192 fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
193 where
194 S: Encode,
195 {
196 self.client._send_encode_query(stmt)
197 }
198}