1use std::borrow::Cow;
2use std::fmt::{self, Debug, Formatter};
3use std::ops::{Deref, DerefMut};
4
5use futures_core::future::BoxFuture;
6
7use crate::database::Database;
8use crate::error::Error;
9use crate::pool::MaybePoolConnection;
10
11#[doc(hidden)]
15pub trait TransactionManager {
16 type Database: Database;
17
18 fn begin(
20 conn: &mut <Self::Database as Database>::Connection,
21 ) -> BoxFuture<'_, Result<(), Error>>;
22
23 fn commit(
25 conn: &mut <Self::Database as Database>::Connection,
26 ) -> BoxFuture<'_, Result<(), Error>>;
27
28 fn rollback(
30 conn: &mut <Self::Database as Database>::Connection,
31 ) -> BoxFuture<'_, Result<(), Error>>;
32
33 fn start_rollback(conn: &mut <Self::Database as Database>::Connection);
35}
36
37pub struct Transaction<'c, DB>
54where
55 DB: Database,
56{
57 connection: MaybePoolConnection<'c, DB>,
58 open: bool,
59}
60
61impl<'c, DB> Transaction<'c, DB>
62where
63 DB: Database,
64{
65 pub(crate) fn begin(
66 conn: impl Into<MaybePoolConnection<'c, DB>>,
67 ) -> BoxFuture<'c, Result<Self, Error>> {
68 let mut conn = conn.into();
69
70 Box::pin(async move {
71 DB::TransactionManager::begin(&mut conn).await?;
72
73 Ok(Self {
74 connection: conn,
75 open: true,
76 })
77 })
78 }
79
80 pub async fn commit(mut self) -> Result<(), Error> {
82 DB::TransactionManager::commit(&mut self.connection).await?;
83 self.open = false;
84
85 Ok(())
86 }
87
88 pub async fn rollback(mut self) -> Result<(), Error> {
90 DB::TransactionManager::rollback(&mut self.connection).await?;
91 self.open = false;
92
93 Ok(())
94 }
95}
96
#[allow(unused_macros)]
// Generates an implementation of `crate::executor::Executor<'t>` for
// `&'t mut Transaction<'c, $DB>`, so a transaction can be used wherever an executor is
// expected.
//
// `$DB` is the backend's `Database` type. `$Row` is the row type yielded by
// `fetch_many` / `fetch_optional`.
//
// Every method reborrows the transaction's underlying connection with `&mut **self` and
// forwards its arguments unchanged. The first `*` goes through the `&'t mut`; the second
// goes through `Transaction`'s `Deref`/`DerefMut`.
//
// NOTE(review): `unused_macros` is allowed presumably because only some (feature-gated)
// backends invoke this macro — confirm.
macro_rules! impl_executor_for_transaction {
    ($DB:ident, $Row:ident) => {
        impl<'c, 't> crate::executor::Executor<'t>
            for &'t mut crate::transaction::Transaction<'c, $DB>
        {
            type Database = $DB;

            // Streams query results and rows by delegating to the connection's `fetch_many`.
            fn fetch_many<'e, 'q: 'e, E: 'q>(
                self,
                query: E,
            ) -> futures_core::stream::BoxStream<
                'e,
                Result<
                    either::Either<<$DB as crate::database::Database>::QueryResult, $Row>,
                    crate::error::Error,
                >,
            >
            where
                't: 'e,
                E: crate::executor::Execute<'q, Self::Database>,
            {
                (&mut **self).fetch_many(query)
            }

            // Fetches at most one row by delegating to the connection's `fetch_optional`.
            fn fetch_optional<'e, 'q: 'e, E: 'q>(
                self,
                query: E,
            ) -> futures_core::future::BoxFuture<'e, Result<Option<$Row>, crate::error::Error>>
            where
                't: 'e,
                E: crate::executor::Execute<'q, Self::Database>,
            {
                (&mut **self).fetch_optional(query)
            }

            // Prepares `sql` with the given parameter type hints, delegating to the
            // connection's `prepare_with`.
            fn prepare_with<'e, 'q: 'e>(
                self,
                sql: &'q str,
                parameters: &'e [<Self::Database as crate::database::Database>::TypeInfo],
            ) -> futures_core::future::BoxFuture<
                'e,
                Result<
                    <Self::Database as crate::database::HasStatement<'q>>::Statement,
                    crate::error::Error,
                >,
            >
            where
                't: 'e,
            {
                (&mut **self).prepare_with(sql, parameters)
            }

            // Hidden from rustdoc; forwards to the connection's `describe`.
            #[doc(hidden)]
            fn describe<'e, 'q: 'e>(
                self,
                query: &'q str,
            ) -> futures_core::future::BoxFuture<
                'e,
                Result<crate::describe::Describe<Self::Database>, crate::error::Error>,
            >
            where
                't: 'e,
            {
                (&mut **self).describe(query)
            }
        }
    };
}
167
168impl<'c, DB> Debug for Transaction<'c, DB>
169where
170 DB: Database,
171{
172 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
173 f.debug_struct("Transaction").finish()
175 }
176}
177
178impl<'c, DB> Deref for Transaction<'c, DB>
179where
180 DB: Database,
181{
182 type Target = DB::Connection;
183
184 #[inline]
185 fn deref(&self) -> &Self::Target {
186 &self.connection
187 }
188}
189
190impl<'c, DB> DerefMut for Transaction<'c, DB>
191where
192 DB: Database,
193{
194 #[inline]
195 fn deref_mut(&mut self) -> &mut Self::Target {
196 &mut self.connection
197 }
198}
199
200impl<'c, DB> Drop for Transaction<'c, DB>
201where
202 DB: Database,
203{
204 fn drop(&mut self) {
205 if self.open {
206 DB::TransactionManager::start_rollback(&mut self.connection);
213 }
214 }
215}
216
/// Returns the ANSI SQL that opens a transaction level, given the current nesting `depth`.
///
/// At depth `0` (nothing open) this is a plain `BEGIN`. Any deeper level gets a savepoint
/// named after `depth`. That is the same index `commit_ansi_transaction_sql` and
/// `rollback_ansi_transaction_sql` compute as `depth - 1` once the depth has gone up by one.
#[allow(dead_code)]
pub(crate) fn begin_ansi_transaction_sql(depth: usize) -> Cow<'static, str> {
    match depth {
        0 => Cow::Borrowed("BEGIN"),
        level => Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{}", level)),
    }
}
225
/// Returns the ANSI SQL that commits the innermost level, given the nesting `depth` that
/// includes the level being committed.
///
/// Depth `1` is the outermost transaction and gets `COMMIT`. Deeper levels release
/// savepoint `depth - 1`, the name `begin_ansi_transaction_sql` used one level down.
///
/// # Panics
///
/// With `depth == 0`, `depth - 1` overflows: this panics in debug builds and wraps in
/// release builds. Callers are expected to check that a transaction is open first.
#[allow(dead_code)]
pub(crate) fn commit_ansi_transaction_sql(depth: usize) -> Cow<'static, str> {
    if depth != 1 {
        let savepoint = depth - 1;
        return format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", savepoint).into();
    }

    "COMMIT".into()
}
234
/// Returns the ANSI SQL that rolls back the innermost level, given the nesting `depth`
/// that includes the level being rolled back.
///
/// Depth `1` is the outermost transaction and gets `ROLLBACK`. Deeper levels roll back to
/// savepoint `depth - 1`, the name `begin_ansi_transaction_sql` used one level down.
///
/// # Panics
///
/// With `depth == 0`, `depth - 1` overflows: this panics in debug builds and wraps in
/// release builds. Callers are expected to check that a transaction is open first.
#[allow(dead_code)]
pub(crate) fn rollback_ansi_transaction_sql(depth: usize) -> Cow<'static, str> {
    match depth {
        1 => Cow::Borrowed("ROLLBACK"),
        _ => {
            let savepoint = depth - 1;
            Cow::Owned(format!(
                "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}",
                savepoint
            ))
        }
    }
}