xitca_postgres/
transaction.rs

1mod builder;
2mod portal;
3
4use std::borrow::Cow;
5
6use super::{
7    client::ClientBorrowMut,
8    driver::codec::{AsParams, Response, encode::Encode},
9    error::Error,
10    execute::Execute,
11    prepare::Prepare,
12    query::Query,
13    statement::Statement,
14    types::{Oid, ToSql, Type},
15};
16
17pub use builder::{IsolationLevel, TransactionBuilder};
18pub use portal::Portal;
19
20pub struct Transaction<C>
21where
22    C: Prepare + ClientBorrowMut,
23{
24    client: C,
25    save_point: SavePoint,
26    state: State,
27}
28
/// Savepoint bookkeeping for a (possibly nested) transaction.
///
/// `depth` is the nesting level of a savepoint: 1 for a savepoint created directly inside a
/// top-level transaction, 2 for one nested inside that, and so on.
enum SavePoint {
    /// Top-level transaction; no savepoint is involved.
    None,
    /// Savepoint whose name is generated as `sp_{depth}`.
    Auto { depth: u32 },
    /// Savepoint with a caller supplied name.
    ///
    /// NOTE(review): `name` is interpolated into the SQL text verbatim (unquoted), so it must be
    /// a valid SQL identifier and should not come from untrusted input.
    Custom { name: String, depth: u32 },
}

impl SavePoint {
    /// Returns the savepoint one nesting level below `self`.
    ///
    /// `Some(name)` yields a `Custom` savepoint with that name; `None` yields an `Auto` savepoint
    /// named after its depth. The depth always increases by one, whether or not a custom name is
    /// given, so generated names keep tracking the real nesting level.
    fn nest_save_point(&self, name: Option<String>) -> Self {
        let depth = self.depth() + 1;
        match name {
            Some(name) => Self::Custom { name, depth },
            None => Self::Auto { depth },
        }
    }

    /// Nesting depth of this level; 0 for a top-level transaction.
    fn depth(&self) -> u32 {
        match *self {
            Self::None => 0,
            Self::Auto { depth } | Self::Custom { depth, .. } => depth,
        }
    }

    /// SQL that creates this savepoint.
    ///
    /// Only meaningful for values produced by `nest_save_point`, which never returns `None`.
    fn save_point_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("SAVEPOINT"),
            Self::Auto { depth } => Cow::Owned(format!("SAVEPOINT sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("SAVEPOINT {name}")),
        }
    }

    /// SQL that successfully finishes this level.
    ///
    /// This is `COMMIT` at the top level and `RELEASE <savepoint>` when nested.
    fn commit_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("COMMIT"),
            Self::Auto { depth } => Cow::Owned(format!("RELEASE sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("RELEASE {name}")),
        }
    }

    /// SQL that discards the work done at this level.
    ///
    /// This is `ROLLBACK` at the top level and `ROLLBACK TO <savepoint>` when nested.
    fn rollback_query(&self) -> Cow<'static, str> {
        match self {
            Self::None => Cow::Borrowed("ROLLBACK"),
            Self::Auto { depth } => Cow::Owned(format!("ROLLBACK TO sp_{depth}")),
            Self::Custom { name, .. } => Cow::Owned(format!("ROLLBACK TO {name}")),
        }
    }
}
73
/// Records whether dropping a transaction must still send a rollback.
enum State {
    /// `commit` or `rollback` was called explicitly; drop does nothing further.
    Finish,
    /// No explicit finish yet; drop issues the rollback query.
    WantRollback,
}
78
79impl<C> Drop for Transaction<C>
80where
81    C: Prepare + ClientBorrowMut,
82{
83    fn drop(&mut self) {
84        match self.state {
85            State::WantRollback => self.do_rollback(),
86            State::Finish => {}
87        }
88    }
89}
90
91impl<C> Transaction<C>
92where
93    C: Prepare + ClientBorrowMut,
94{
95    /// Binds a statement to a set of parameters, creating a [`Portal`] which can be incrementally queried.
96    ///
97    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
98    /// connection that created them.
99    pub async fn bind<'p>(
100        &'p self,
101        statement: &'p Statement,
102        params: &[&(dyn ToSql + Sync)],
103    ) -> Result<Portal<'p, C>, Error> {
104        self.bind_raw(statement, params.iter().cloned()).await
105    }
106
107    /// A maximally flexible version of [`Transaction::bind`].
108    pub async fn bind_raw<'p, I>(&'p self, statement: &'p Statement, params: I) -> Result<Portal<'p, C>, Error>
109    where
110        I: AsParams,
111    {
112        Portal::new(&self.client, statement, params).await
113    }
114
115    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint.
116    ///     
117    /// [`Client::transaction`]: crate::client::Client::transaction
118    pub async fn transaction(&mut self) -> Result<Transaction<&mut C>, Error> {
119        self._save_point(None).await
120    }
121
122    /// Like [`Client::transaction`], but creates a nested transaction via a savepoint with the specified name.
123    ///
124    /// [`Client::transaction`]: crate::client::Client::transaction
125    pub async fn save_point<I>(&mut self, name: I) -> Result<Transaction<&mut C>, Error>
126    where
127        I: Into<String>,
128    {
129        self._save_point(Some(name.into())).await
130    }
131
132    /// Consumes the transaction, committing all changes made within it.
133    pub async fn commit(mut self) -> Result<(), Error> {
134        self.state = State::Finish;
135        self.save_point.commit_query().execute(&self).await?;
136        Ok(())
137    }
138
139    /// Rolls the transaction back, discarding all changes made within it.
140    ///
141    /// This is equivalent to [`Transaction`]'s [`Drop`] implementation, but provides any error encountered to the caller.
142    pub async fn rollback(mut self) -> Result<(), Error> {
143        self.state = State::Finish;
144        self.save_point.rollback_query().execute(&self).await?;
145        Ok(())
146    }
147
148    fn new(client: C) -> Transaction<C> {
149        Transaction {
150            client,
151            save_point: SavePoint::None,
152            state: State::WantRollback,
153        }
154    }
155
156    async fn _save_point(&mut self, name: Option<String>) -> Result<Transaction<&mut C>, Error> {
157        let save_point = self.save_point.nest_save_point(name);
158        save_point.save_point_query().execute(self).await?;
159
160        Ok(Transaction {
161            client: &mut self.client,
162            save_point,
163            state: State::WantRollback,
164        })
165    }
166
167    fn do_rollback(&mut self) {
168        drop(self.save_point.rollback_query().execute(self));
169    }
170}
171
172impl<C> Prepare for Transaction<C>
173where
174    C: Prepare + ClientBorrowMut,
175{
176    #[inline]
177    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
178        self.client._get_type(oid).await
179    }
180
181    #[inline]
182    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
183        self.client._get_type_blocking(oid)
184    }
185}
186
187impl<C> Query for Transaction<C>
188where
189    C: Prepare + ClientBorrowMut,
190{
191    #[inline]
192    fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
193    where
194        S: Encode,
195    {
196        self.client._send_encode_query(stmt)
197    }
198}