Skip to main content

xitca_postgres/
transaction.rs

1mod builder;
2mod portal;
3
4use core::ops::{Deref, DerefMut};
5
6use std::borrow::Cow;
7
8use super::{
9    client::{Client, ClientBorrowMut},
10    driver::codec::AsParams,
11    error::Error,
12    execute::Execute,
13    pool::PoolConnection,
14    statement::Statement,
15    types::ToSql,
16};
17
18pub use builder::{IsolationLevel, TransactionBuilder};
19pub use portal::Portal;
20
/// Bookkeeping for one level of a (possibly nested) transaction.
struct SavePoint {
    /// Caller-chosen savepoint name. `None` means a generated `sp_<n>` name is used instead.
    name: Option<String>,
    /// Nesting level: 0 is the outermost transaction, each nested savepoint adds one.
    depth: u32,
    /// Whether this level still has to be rolled back when it is dropped.
    state: State,
}

/// Completion state of a [`SavePoint`].
enum State {
    /// Neither commit nor rollback has been issued yet; dropping must roll back.
    WantRollback,
    /// Commit or rollback has been issued; nothing is left to do on drop.
    Finish,
}

impl Default for SavePoint {
    /// The outermost level: unnamed, depth 0, still awaiting commit or rollback.
    fn default() -> Self {
        SavePoint {
            state: State::WantRollback,
            depth: 0,
            name: None,
        }
    }
}
41
42impl SavePoint {
43    // rollback runs in Drop trait impl. the execution part has to run eargerly
44    fn rollback(&mut self, cli: impl ClientBorrowMut) -> impl Future<Output = Result<(), Error>> + Send {
45        self.state = State::Finish;
46
47        let fut = match self.depth {
48            0 => Cow::Borrowed("ROLLBACK"),
49            depth => match self.name {
50                None => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
51                Some(ref name) => Cow::Owned(format!("ROLLBACK TO {name}")),
52            },
53        }
54        .execute(cli.borrow_cli_ref());
55
56        async { fut.await.map(|_| ()) }
57    }
58
59    async fn commit(&mut self, cli: impl ClientBorrowMut) -> Result<(), Error> {
60        self.state = State::Finish;
61
62        match self.depth {
63            0 => Cow::Borrowed("COMMIT"),
64            depth => match self.name {
65                None => Cow::Owned(format!("RELEASE sp_{depth}")),
66                Some(ref name) => Cow::Owned(format!("RELEASE {name}")),
67            },
68        }
69        .execute(cli.borrow_cli_ref())
70        .await
71        .map(|_| ())
72    }
73
74    async fn nest_save_point(&self, cli: impl ClientBorrowMut, name: Option<String>) -> Result<Self, Error> {
75        let depth = self.depth + 1;
76
77        match self.depth {
78            0 => match name {
79                Some(ref name) => Cow::Owned(format!("SAVEPOINT {name}")),
80                None => Cow::Borrowed("SAVEPOINT sp_1"),
81            },
82            depth => match name {
83                Some(ref name) => Cow::Owned(format!("SAVEPOINT {name}")),
84                None => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
85            },
86        }
87        .execute(cli.borrow_cli_ref())
88        .await
89        .map(|_| SavePoint {
90            name,
91            depth,
92            state: State::WantRollback,
93        })
94    }
95
96    fn on_drop(&mut self, cli: impl ClientBorrowMut) {
97        if matches!(self.state, State::WantRollback) {
98            drop(self.rollback(cli));
99        }
100    }
101}
102
103pub struct Transaction<'a, C>
104where
105    C: ClientBorrowMut,
106{
107    client: _Client<'a, C>,
108    save_point: SavePoint,
109}
110
/// Either an owned client or a mutable borrow of one, so a [`Transaction`] can hold its client in
/// both forms. Dereferences to `C` in either case.
enum _Client<'a, C> {
    Owned(C),
    Borrowed(&'a mut C),
}

impl<C> _Client<'_, C> {
    /// Returns a `Borrowed` view of the same client for the duration of this `&mut` borrow. This
    /// works whether `self` owns the client or already borrows it.
    #[inline]
    fn reborrow(&mut self) -> _Client<'_, C> {
        let inner = &mut **self;
        _Client::Borrowed(inner)
    }
}

impl<C> Deref for _Client<'_, C> {
    type Target = C;

    fn deref(&self) -> &C {
        match self {
            Self::Owned(owned) => owned,
            Self::Borrowed(borrowed) => &**borrowed,
        }
    }
}

impl<C> DerefMut for _Client<'_, C> {
    fn deref_mut(&mut self) -> &mut C {
        match self {
            Self::Owned(owned) => owned,
            Self::Borrowed(borrowed) => &mut **borrowed,
        }
    }
}
142
143impl<C> Drop for Transaction<'_, C>
144where
145    C: ClientBorrowMut,
146{
147    fn drop(&mut self) {
148        self.save_point.on_drop(self.client.deref_mut());
149    }
150}
151
152impl<C> Transaction<'_, C>
153where
154    C: ClientBorrowMut,
155{
156    /// A maximally flexible version of [`Transaction::bind`].
157    pub async fn bind<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
158    where
159        I: AsParams,
160    {
161        Portal::new(&*self.client, statement, params).await
162    }
163
164    /// Binds a statement to a set of parameters, creating a [`Portal`] which can be incrementally queried.
165    ///
166    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
167    /// connection that created them.
168    pub async fn bind_dyn<'p>(
169        &'p self,
170        statement: &'p Statement,
171        params: &[&(dyn ToSql + Sync)],
172    ) -> Result<Portal<'p, C>, Error> {
173        self.bind(statement, params.iter().cloned()).await
174    }
175
176    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint.
177    ///     
178    /// [`Client::transaction`]: crate::client::Client::transaction
179    pub async fn transaction(&mut self) -> Result<Transaction<'_, C>, Error> {
180        self._save_point(None).await
181    }
182
183    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint with the specified name.
184    ///
185    /// [`Client::transaction`]: crate::client::Client::transaction
186    pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<'_, C>, Error>
187    where
188        I: Into<String>,
189    {
190        self._save_point(Some(name.into())).await
191    }
192
193    /// Consumes the transaction, committing all changes made within it.
194    pub async fn commit(mut self) -> Result<(), Error> {
195        self.save_point.commit(self.client.deref_mut()).await
196    }
197
198    /// Rolls the transaction back, discarding all changes made within it.
199    ///
200    /// This is equivalent to [`Transaction`]'s [`Drop`] implementation, but provides any error encountered to the caller.
201    pub async fn rollback(mut self) -> Result<(), Error> {
202        self.save_point.rollback(self.client.deref_mut()).await
203    }
204
205    fn new(client: &mut C) -> Transaction<'_, C> {
206        Transaction {
207            client: _Client::Borrowed(client),
208            save_point: SavePoint::default(),
209        }
210    }
211
212    fn new_owned<'a>(client: C) -> Transaction<'a, C>
213    where
214        C: 'a,
215    {
216        Transaction {
217            client: _Client::Owned(client),
218            save_point: SavePoint::default(),
219        }
220    }
221
222    async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<'_, C>, Error> {
223        let save_point = self.save_point.nest_save_point(self.client.deref_mut(), name).await?;
224        Ok(Transaction {
225            client: self.client.reborrow(),
226            save_point,
227        })
228    }
229}
230
231impl<'c, C, Q, EO, QO> Execute<&'c Transaction<'_, C>> for Q
232where
233    C: ClientBorrowMut,
234    Q: Execute<&'c Client, ExecuteOutput = EO, QueryOutput = QO>,
235{
236    type ExecuteOutput = EO;
237    type QueryOutput = QO;
238
239    #[inline]
240    fn execute(self, cli: &'c Transaction<'_, C>) -> Self::ExecuteOutput {
241        Q::execute(self, cli.client.borrow_cli_ref())
242    }
243
244    #[inline]
245    fn query(self, cli: &'c Transaction<C>) -> Self::QueryOutput {
246        Q::query(self, cli.client.borrow_cli_ref())
247    }
248}
249
250// special treatment for pool connection for it's internal caching logic that are not accessible through Query trait
251impl<'c, 'p, Q, EO, QO> Execute<&'c mut Transaction<'_, PoolConnection<'p>>> for Q
252where
253    Q: Execute<&'c mut PoolConnection<'p>, ExecuteOutput = EO, QueryOutput = QO>,
254{
255    type ExecuteOutput = EO;
256    type QueryOutput = QO;
257
258    #[inline]
259    fn execute(self, cli: &'c mut Transaction<'_, PoolConnection<'p>>) -> Self::ExecuteOutput {
260        Q::execute(self, cli.client.deref_mut())
261    }
262
263    #[inline]
264    fn query(self, cli: &'c mut Transaction<PoolConnection<'p>>) -> Self::QueryOutput {
265        Q::query(self, cli.client.deref_mut())
266    }
267}