xitca_postgres/
transaction.rs

1mod builder;
2mod portal;
3
4use std::borrow::Cow;
5
6use super::{
7    client::ClientBorrowMut,
8    driver::codec::{encode::Encode, AsParams, Response},
9    error::Error,
10    execute::Execute,
11    prepare::Prepare,
12    query::Query,
13    statement::Statement,
14    types::{Oid, ToSql, Type},
15    BoxedFuture,
16};
17
18pub use builder::TransactionBuilder;
19pub use portal::Portal;
20
21pub struct Transaction<'a, C>
22where
23    C: Prepare + ClientBorrowMut,
24{
25    client: &'a mut C,
26    save_point: SavePoint,
27    state: State,
28}
29
/// Which transaction boundary a [`Transaction`] owns, and how that boundary is named in SQL.
///
/// `depth` is the nesting level of a savepoint. It is 1 for a savepoint opened directly inside
/// a top-level transaction, 2 for one nested inside that, and so on.
enum SavePoint {
    /// A top-level transaction, finished with `COMMIT` or `ROLLBACK`.
    None,
    /// A savepoint whose name is generated as `sp_{depth}`.
    Auto { depth: u32 },
    /// A savepoint whose name was supplied by the caller.
    Custom { name: String, depth: u32 },
}

impl SavePoint {
    /// Returns the savepoint one nesting level below `self`.
    ///
    /// With `Some(name)` the savepoint uses that name. With `None` a name of the form
    /// `sp_{depth}` is generated. In both cases the new depth is the parent's depth plus one,
    /// where a top-level transaction counts as depth 0. Named and generated savepoints are
    /// therefore counted the same way, and generated names follow the real nesting level.
    fn nest_save_point(&self, name: Option<String>) -> Self {
        let depth = self.depth() + 1;
        match name {
            Some(name) => Self::Custom { name, depth },
            None => Self::Auto { depth },
        }
    }

    /// Nesting depth of this boundary: `0` for a top-level transaction.
    fn depth(&self) -> u32 {
        match *self {
            Self::None => 0,
            Self::Auto { depth } | Self::Custom { depth, .. } => depth,
        }
    }

    /// SQL that creates this savepoint.
    ///
    /// The `None` arm yields a bare `SAVEPOINT`, which is not a complete statement. It is not
    /// produced by [`SavePoint::nest_save_point`], which always returns `Auto` or `Custom`.
    ///
    /// NOTE(review): a `Custom` name is spliced into the SQL verbatim. It must be a valid SQL
    /// identifier and must not come from untrusted input.
    fn save_point_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("SAVEPOINT"),
            Self::Auto { depth } => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("SAVEPOINT {name}")),
        }
    }

    /// SQL that commits this boundary: `COMMIT` at the top level, `RELEASE <savepoint>` otherwise.
    fn commit_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("COMMIT"),
            Self::Auto { depth } => Cow::Owned(format!("RELEASE sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("RELEASE {name}")),
        }
    }

    /// SQL that discards this boundary's work: `ROLLBACK` at the top level,
    /// `ROLLBACK TO <savepoint>` otherwise.
    fn rollback_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("ROLLBACK"),
            Self::Auto { depth } => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("ROLLBACK TO {name}")),
        }
    }
}
74
/// Lifecycle flag consulted by `Drop` on [`Transaction`].
enum State {
    /// `commit` or `rollback` has already been called; dropping does nothing.
    Finish,
    /// Still open; dropping issues the rollback statement.
    WantRollback,
}
79
80impl<C> Drop for Transaction<'_, C>
81where
82    C: Prepare + ClientBorrowMut,
83{
84    fn drop(&mut self) {
85        match self.state {
86            State::WantRollback => self.do_rollback(),
87            State::Finish => {}
88        }
89    }
90}
91
92impl<C> Transaction<'_, C>
93where
94    C: Prepare + ClientBorrowMut,
95{
96    pub fn builder() -> TransactionBuilder {
97        TransactionBuilder::new()
98    }
99
100    /// Binds a statement to a set of parameters, creating a [`Portal`] which can be incrementally queried.
101    ///
102    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
103    /// connection that created them.
104    pub async fn bind<'p>(
105        &'p self,
106        statement: &'p Statement,
107        params: &[&(dyn ToSql + Sync)],
108    ) -> Result<Portal<'p, C>, Error> {
109        self.bind_raw(statement, params.iter().cloned()).await
110    }
111
112    /// A maximally flexible version of [`Transaction::bind`].
113    pub async fn bind_raw<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
114    where
115        I: AsParams,
116    {
117        Portal::new(self.client, statement, params).await
118    }
119
120    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint.
121    ///     
122    /// [`Client::transaction`]: crate::client::Client::transaction
123    pub async fn transaction(&mut self) -> Result<Transaction<C>, Error> {
124        self._save_point(None).await
125    }
126
127    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint with the specified name.
128    ///
129    /// [`Client::transaction`]: crate::client::Client::transaction
130    pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<C>, Error>
131    where
132        I: Into<String>,
133    {
134        self._save_point(Some(name.into())).await
135    }
136
137    /// Consumes the transaction, committing all changes made within it.
138    pub async fn commit(mut self) -> Result<(), Error> {
139        self.state = State::Finish;
140        self.save_point.commit_query().execute(&self).await?;
141        Ok(())
142    }
143
144    /// Rolls the transaction back, discarding all changes made within it.
145    ///
146    /// This is equivalent to [`Transaction`]'s [`Drop`] implementation, but provides any error encountered to the caller.
147    pub async fn rollback(mut self) -> Result<(), Error> {
148        self.state = State::Finish;
149        self.save_point.rollback_query().execute(&self).await?;
150        Ok(())
151    }
152
153    fn new(client: &mut C) -> Transaction<C> {
154        Transaction {
155            client,
156            save_point: SavePoint::None,
157            state: State::WantRollback,
158        }
159    }
160
161    async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<C>, Error> {
162        let save_point = self.save_point.nest_save_point(name);
163        save_point.save_point_query().execute(self).await?;
164
165        Ok(Transaction {
166            client: self.client,
167            save_point,
168            state: State::WantRollback,
169        })
170    }
171
172    fn do_rollback(&mut self) {
173        drop(self.save_point.rollback_query().execute(self));
174    }
175}
176
177impl<C> Prepare for Transaction<'_, C>
178where
179    C: Prepare + ClientBorrowMut,
180{
181    #[inline]
182    fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
183        self.client._get_type(oid)
184    }
185
186    #[inline]
187    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
188        self.client._get_type_blocking(oid)
189    }
190}
191
192impl<C> Query for Transaction<'_, C>
193where
194    C: Prepare + ClientBorrowMut,
195{
196    #[inline]
197    fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
198    where
199        S: Encode,
200    {
201        self.client._send_encode_query(stmt)
202    }
203}