zero_postgres/tokio/
transaction.rs

1//! Transaction support for asynchronous PostgreSQL connections.
2
3use super::Conn;
4use super::named_portal::NamedPortal;
5use crate::conversion::ToParams;
6use crate::error::{Error, Result};
7use crate::statement::IntoStatement;
8
/// A PostgreSQL transaction for the asynchronous connection.
///
/// This struct provides transaction control. It does not hold the connection
/// itself: the connection is passed to the `commit` and `rollback` methods,
/// which execute the transaction commands on it.
#[derive(Debug)]
pub struct Transaction {
    /// Id of the connection that started this transaction. Methods compare it
    /// against the id of the connection they are handed and reject a mismatch.
    connection_id: u32,
}
16
17impl Transaction {
18    /// Create a new transaction (internal use only).
19    pub(crate) fn new(connection_id: u32) -> Self {
20        Self { connection_id }
21    }
22
23    /// Commit the transaction.
24    ///
25    /// This consumes the transaction and sends a COMMIT statement to the server.
26    /// The connection must be passed as an argument to execute the commit.
27    ///
28    /// # Errors
29    ///
30    /// Returns `Error::InvalidUsage` if the connection is not the same
31    /// as the one that started the transaction.
32    pub async fn commit(self, conn: &mut Conn) -> Result<()> {
33        let actual = conn.connection_id();
34        if self.connection_id != actual {
35            return Err(Error::InvalidUsage(format!(
36                "connection mismatch: expected {}, got {}",
37                self.connection_id, actual
38            )));
39        }
40        conn.query_drop("COMMIT").await?;
41        Ok(())
42    }
43
44    /// Rollback the transaction.
45    ///
46    /// This consumes the transaction and sends a ROLLBACK statement to the server.
47    /// The connection must be passed as an argument to execute the rollback.
48    ///
49    /// # Errors
50    ///
51    /// Returns `Error::InvalidUsage` if the connection is not the same
52    /// as the one that started the transaction.
53    pub async fn rollback(self, conn: &mut Conn) -> Result<()> {
54        let actual = conn.connection_id();
55        if self.connection_id != actual {
56            return Err(Error::InvalidUsage(format!(
57                "connection mismatch: expected {}, got {}",
58                self.connection_id, actual
59            )));
60        }
61        conn.query_drop("ROLLBACK").await?;
62        Ok(())
63    }
64
65    /// Create a named portal for iterative row fetching within this transaction.
66    ///
67    /// Named portals are safe to use within an explicit transaction because
68    /// SYNC messages do not destroy them (only COMMIT/ROLLBACK does).
69    ///
70    /// The statement can be either:
71    /// - A `&PreparedStatement` returned from `conn.prepare()`
72    /// - A raw SQL `&str` for one-shot execution
73    ///
74    /// # Example
75    ///
76    /// ```ignore
77    /// conn.transaction(|conn, tx| async move {
78    ///     let mut portal = tx.exec_portal_named(conn, &stmt, ()).await?;
79    ///
80    ///     while !portal.is_complete() {
81    ///         let rows: Vec<(i32,)> = portal.exec_collect(conn, 100).await?;
82    ///         process(rows);
83    ///     }
84    ///
85    ///     portal.close(conn).await?;
86    ///     tx.commit(conn).await
87    /// }).await?;
88    /// ```
89    ///
90    /// # Errors
91    ///
92    /// Returns `Error::InvalidUsage` if the connection is not the same
93    /// as the one that started the transaction.
94    pub async fn exec_portal_named<S: IntoStatement, P: ToParams>(
95        &self,
96        conn: &mut Conn,
97        statement: S,
98        params: P,
99    ) -> Result<NamedPortal<'_>> {
100        let actual = conn.connection_id();
101        if self.connection_id != actual {
102            return Err(Error::InvalidUsage(format!(
103                "connection mismatch: expected {}, got {}",
104                self.connection_id, actual
105            )));
106        }
107
108        let portal_name = conn.next_portal_name();
109        let result = conn
110            .create_named_portal(&portal_name, &statement, &params)
111            .await;
112
113        if let Err(e) = &result {
114            if e.is_connection_broken() {
115                conn.is_broken = true;
116            }
117            return Err(result.unwrap_err());
118        }
119
120        Ok(NamedPortal::new(portal_name))
121    }
122}