uni_db/api/
transaction.rs1use std::collections::HashMap;
5use std::time::Instant;
6
7use futures::future::BoxFuture;
8use metrics;
9use tracing::{error, info, instrument, warn};
10use uuid::Uuid;
11
12use crate::api::Uni;
13use uni_common::{Result, UniError};
14use uni_query::{ExecuteResult, QueryResult, Value};
15
16pub struct Transaction<'a> {
31 db: &'a Uni,
32 completed: bool,
33 id: String,
34 start_time: Instant,
35}
36
37impl<'a> Transaction<'a> {
38 pub(crate) async fn new(db: &'a Uni) -> Result<Self> {
39 let writer_lock = db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
40 operation: "start_transaction".to_string(),
41 })?;
42 let mut writer = writer_lock.write().await;
43 writer.begin_transaction()?;
44 let id = Uuid::new_v4().to_string();
45 info!(transaction_id = %id, "Transaction started");
46 Ok(Self {
47 db,
48 completed: false,
49 id,
50 start_time: Instant::now(),
51 })
52 }
53
54 pub async fn query(&self, cypher: &str) -> Result<QueryResult> {
64 self.db.execute_internal(cypher, HashMap::new()).await
65 }
66
67 pub async fn execute(&self, cypher: &str) -> Result<ExecuteResult> {
77 let before = self.db.get_mutation_count().await;
78 let result = self.query(cypher).await?;
79 let affected_rows = if result.is_empty() {
80 self.db.get_mutation_count().await.saturating_sub(before)
81 } else {
82 result.len()
83 };
84 Ok(ExecuteResult { affected_rows })
85 }
86
87 pub fn execute_with(&self, cypher: &str) -> TransactionQueryBuilder<'_> {
92 TransactionQueryBuilder {
93 tx: self,
94 cypher: cypher.to_string(),
95 params: HashMap::new(),
96 }
97 }
98
99 #[instrument(skip(self), fields(transaction_id = %self.id, duration_ms), level = "info")]
104 pub async fn commit(mut self) -> Result<()> {
105 if self.completed {
106 return Err(uni_common::UniError::TransactionAlreadyCompleted);
107 }
108 let writer_lock = self.db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
109 operation: "commit".to_string(),
110 })?;
111 let mut writer = writer_lock.write().await;
112 writer.commit_transaction().await?;
113 self.completed = true;
114 let duration = self.start_time.elapsed();
115 tracing::Span::current().record("duration_ms", duration.as_millis());
116 metrics::histogram!("uni_transaction_duration_seconds").record(duration.as_secs_f64());
117 metrics::counter!("uni_transaction_commits_total").increment(1);
118 info!("Transaction committed");
119 Ok(())
120 }
121
122 #[instrument(skip(self), fields(transaction_id = %self.id, duration_ms), level = "info")]
126 pub async fn rollback(mut self) -> Result<()> {
127 if self.completed {
128 return Err(uni_common::UniError::TransactionAlreadyCompleted);
129 }
130 let writer_lock = self.db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
131 operation: "rollback".to_string(),
132 })?;
133 let mut writer = writer_lock.write().await;
134 writer.rollback_transaction()?;
135 self.completed = true;
136 let duration = self.start_time.elapsed();
137 tracing::Span::current().record("duration_ms", duration.as_millis());
138 metrics::histogram!("uni_transaction_duration_seconds").record(duration.as_secs_f64());
139 metrics::counter!("uni_transaction_rollbacks_total").increment(1);
140 info!("Transaction rolled back");
141 Ok(())
142 }
143}
144
145impl Drop for Transaction<'_> {
146 fn drop(&mut self) {
147 if !self.completed {
148 warn!(
149 transaction_id = %self.id,
150 "Transaction dropped without commit or rollback — auto-rolling back"
151 );
152 if let Some(writer_lock) = self.db.writer.as_ref() {
153 match writer_lock.try_write() {
155 Ok(mut writer) => writer.force_rollback(),
156 Err(_) => error!(
157 transaction_id = %self.id,
158 "Could not acquire writer lock for auto-rollback"
159 ),
160 }
161 }
162 }
163 }
164}
165
166pub struct TransactionQueryBuilder<'a> {
168 tx: &'a Transaction<'a>,
169 cypher: String,
170 params: HashMap<String, Value>,
171}
172
173impl<'a> TransactionQueryBuilder<'a> {
174 pub fn param(mut self, name: &str, value: impl Into<Value>) -> Self {
176 self.params.insert(name.to_string(), value.into());
177 self
178 }
179
180 pub async fn execute(self) -> Result<ExecuteResult> {
182 let before = self.tx.db.get_mutation_count().await;
183 let result = self
184 .tx
185 .db
186 .execute_internal(&self.cypher, self.params)
187 .await?;
188 let affected_rows = if result.is_empty() {
189 self.tx.db.get_mutation_count().await.saturating_sub(before)
190 } else {
191 result.len()
192 };
193 Ok(ExecuteResult { affected_rows })
194 }
195}
196
197impl Uni {
198 pub async fn begin(&self) -> Result<Transaction<'_>> {
199 Transaction::new(self).await
200 }
201
202 pub async fn transaction<'a, F, T>(&'a self, f: F) -> Result<T>
203 where
204 F: for<'b> FnOnce(&'b mut Transaction<'a>) -> BoxFuture<'b, Result<T>>,
205 {
206 let mut tx = self.begin().await?;
207
208 match f(&mut tx).await {
209 Ok(v) => match tx.commit().await {
210 Ok(_) => Ok(v),
211 Err(uni_common::UniError::TransactionAlreadyCompleted) => Ok(v),
212 Err(e) => Err(e),
213 },
214 Err(e) => {
215 if let Err(rollback_err) = tx.rollback().await {
217 error!(
218 "Transaction rollback failed during error recovery: {}",
219 rollback_err
220 );
221 }
222 Err(e)
223 }
224 }
225 }
226}