zero_postgres/tokio/transaction.rs
1//! Transaction support for asynchronous PostgreSQL connections.
2
3use super::Conn;
4use super::named_portal::NamedPortal;
5use crate::conversion::ToParams;
6use crate::error::{Error, Result};
7use crate::statement::IntoStatement;
8
/// A PostgreSQL transaction for the asynchronous connection.
///
/// The handle does not hold a reference to the connection. It only stores
/// the id of the connection the transaction was started on, and the
/// connection itself is passed to `commit`, `rollback` and
/// `exec_portal_named`, which compare its id against the stored one before
/// sending anything.
#[derive(Debug)]
pub struct Transaction {
    /// Id of the connection that started this transaction.
    connection_id: u32,
}
16
17impl Transaction {
18 /// Create a new transaction (internal use only).
19 pub(crate) fn new(connection_id: u32) -> Self {
20 Self { connection_id }
21 }
22
23 /// Commit the transaction.
24 ///
25 /// This consumes the transaction and sends a COMMIT statement to the server.
26 /// The connection must be passed as an argument to execute the commit.
27 ///
28 /// # Errors
29 ///
30 /// Returns `Error::InvalidUsage` if the connection is not the same
31 /// as the one that started the transaction.
32 pub async fn commit(self, conn: &mut Conn) -> Result<()> {
33 let actual = conn.connection_id();
34 if self.connection_id != actual {
35 return Err(Error::InvalidUsage(format!(
36 "connection mismatch: expected {}, got {}",
37 self.connection_id, actual
38 )));
39 }
40 conn.query_drop("COMMIT").await?;
41 Ok(())
42 }
43
44 /// Rollback the transaction.
45 ///
46 /// This consumes the transaction and sends a ROLLBACK statement to the server.
47 /// The connection must be passed as an argument to execute the rollback.
48 ///
49 /// # Errors
50 ///
51 /// Returns `Error::InvalidUsage` if the connection is not the same
52 /// as the one that started the transaction.
53 pub async fn rollback(self, conn: &mut Conn) -> Result<()> {
54 let actual = conn.connection_id();
55 if self.connection_id != actual {
56 return Err(Error::InvalidUsage(format!(
57 "connection mismatch: expected {}, got {}",
58 self.connection_id, actual
59 )));
60 }
61 conn.query_drop("ROLLBACK").await?;
62 Ok(())
63 }
64
65 /// Create a named portal for iterative row fetching within this transaction.
66 ///
67 /// Named portals are safe to use within an explicit transaction because
68 /// SYNC messages do not destroy them (only COMMIT/ROLLBACK does).
69 ///
70 /// The statement can be either:
71 /// - A `&PreparedStatement` returned from `conn.prepare()`
72 /// - A raw SQL `&str` for one-shot execution
73 ///
74 /// # Example
75 ///
76 /// ```ignore
77 /// conn.transaction(|conn, tx| async move {
78 /// let mut portal = tx.exec_portal_named(conn, &stmt, ()).await?;
79 ///
80 /// while !portal.is_complete() {
81 /// let rows: Vec<(i32,)> = portal.execute_collect(conn, 100).await?;
82 /// process(rows);
83 /// }
84 ///
85 /// portal.close(conn).await?;
86 /// tx.commit(conn).await
87 /// }).await?;
88 /// ```
89 ///
90 /// # Errors
91 ///
92 /// Returns `Error::InvalidUsage` if the connection is not the same
93 /// as the one that started the transaction.
94 pub async fn exec_portal_named<S: IntoStatement, P: ToParams>(
95 &self,
96 conn: &mut Conn,
97 statement: S,
98 params: P,
99 ) -> Result<NamedPortal<'_>> {
100 let actual = conn.connection_id();
101 if self.connection_id != actual {
102 return Err(Error::InvalidUsage(format!(
103 "connection mismatch: expected {}, got {}",
104 self.connection_id, actual
105 )));
106 }
107
108 let portal_name = conn.next_portal_name();
109 let result = conn
110 .create_named_portal(&portal_name, &statement, ¶ms)
111 .await;
112
113 if let Err(e) = &result {
114 if e.is_connection_broken() {
115 conn.is_broken = true;
116 }
117 return Err(result.unwrap_err());
118 }
119
120 Ok(NamedPortal::new(portal_name))
121 }
122}