Skip to main content

uni_db/api/
transaction.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4use std::collections::HashMap;
5use std::time::Instant;
6
7use futures::future::BoxFuture;
8use metrics;
9use tracing::{error, info, instrument, warn};
10use uuid::Uuid;
11
12use crate::api::Uni;
13use uni_common::{Result, UniError};
14use uni_query::{ExecuteResult, QueryResult, Value};
15
16/// A database transaction.
17///
18/// Transactions provide ACID guarantees for multiple operations.
19/// Changes are isolated until commit.
20///
21/// # Isolation Level
22///
23/// Uni uses Snapshot Isolation. Reads see a consistent snapshot of the database
24/// at the start of the transaction. Writes are buffered and applied atomically on commit.
25///
26/// # Concurrency
27///
28/// Only one write transaction is active at a time (Single Writer).
29/// Read-only transactions can run concurrently.
30pub struct Transaction<'a> {
31    db: &'a Uni,
32    completed: bool,
33    id: String,
34    start_time: Instant,
35}
36
37impl<'a> Transaction<'a> {
38    pub(crate) async fn new(db: &'a Uni) -> Result<Self> {
39        let writer_lock = db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
40            operation: "start_transaction".to_string(),
41        })?;
42        let mut writer = writer_lock.write().await;
43        writer.begin_transaction()?;
44        let id = Uuid::new_v4().to_string();
45        info!(transaction_id = %id, "Transaction started");
46        Ok(Self {
47            db,
48            completed: false,
49            id,
50            start_time: Instant::now(),
51        })
52    }
53
54    /// Execute a Cypher query within the transaction.
55    ///
56    /// # Arguments
57    ///
58    /// * `cypher` - The Cypher query string.
59    ///
60    /// # Returns
61    ///
62    /// A [`QueryResult`] containing rows and columns.
63    pub async fn query(&self, cypher: &str) -> Result<QueryResult> {
64        self.db.execute_internal(cypher, HashMap::new()).await
65    }
66
67    /// Execute a Cypher query that doesn't return rows (e.g. CREATE, DELETE).
68    ///
69    /// # Arguments
70    ///
71    /// * `cypher` - The Cypher query string.
72    ///
73    /// # Returns
74    ///
75    /// An [`ExecuteResult`] with statistics on affected rows.
76    pub async fn execute(&self, cypher: &str) -> Result<ExecuteResult> {
77        let before = self.db.get_mutation_count().await;
78        let result = self.query(cypher).await?;
79        let affected_rows = if result.is_empty() {
80            self.db.get_mutation_count().await.saturating_sub(before)
81        } else {
82            result.len()
83        };
84        Ok(ExecuteResult { affected_rows })
85    }
86
87    /// Execute a mutation with parameters using a builder.
88    ///
89    /// This is the mutation counterpart to [`query`](Self::query) with params support.
90    /// Use `.param()` to bind parameters, then call `.fetch_all()` or similar.
91    pub fn execute_with(&self, cypher: &str) -> TransactionQueryBuilder<'_> {
92        TransactionQueryBuilder {
93            tx: self,
94            cypher: cypher.to_string(),
95            params: HashMap::new(),
96        }
97    }
98
99    /// Commit the transaction.
100    ///
101    /// Persists all changes made during the transaction.
102    /// If commit fails, the transaction is rolled back.
103    #[instrument(skip(self), fields(transaction_id = %self.id, duration_ms), level = "info")]
104    pub async fn commit(mut self) -> Result<()> {
105        if self.completed {
106            return Err(uni_common::UniError::TransactionAlreadyCompleted);
107        }
108        let writer_lock = self.db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
109            operation: "commit".to_string(),
110        })?;
111        let mut writer = writer_lock.write().await;
112        writer.commit_transaction().await?;
113        self.completed = true;
114        let duration = self.start_time.elapsed();
115        tracing::Span::current().record("duration_ms", duration.as_millis());
116        metrics::histogram!("uni_transaction_duration_seconds").record(duration.as_secs_f64());
117        metrics::counter!("uni_transaction_commits_total").increment(1);
118        info!("Transaction committed");
119        Ok(())
120    }
121
122    /// Rollback the transaction.
123    ///
124    /// Discards all changes made during the transaction.
125    #[instrument(skip(self), fields(transaction_id = %self.id, duration_ms), level = "info")]
126    pub async fn rollback(mut self) -> Result<()> {
127        if self.completed {
128            return Err(uni_common::UniError::TransactionAlreadyCompleted);
129        }
130        let writer_lock = self.db.writer.as_ref().ok_or_else(|| UniError::ReadOnly {
131            operation: "rollback".to_string(),
132        })?;
133        let mut writer = writer_lock.write().await;
134        writer.rollback_transaction()?;
135        self.completed = true;
136        let duration = self.start_time.elapsed();
137        tracing::Span::current().record("duration_ms", duration.as_millis());
138        metrics::histogram!("uni_transaction_duration_seconds").record(duration.as_secs_f64());
139        metrics::counter!("uni_transaction_rollbacks_total").increment(1);
140        info!("Transaction rolled back");
141        Ok(())
142    }
143}
144
145impl Drop for Transaction<'_> {
146    fn drop(&mut self) {
147        if !self.completed {
148            warn!(
149                transaction_id = %self.id,
150                "Transaction dropped without commit or rollback — auto-rolling back"
151            );
152            if let Some(writer_lock) = self.db.writer.as_ref() {
153                // try_write() is non-blocking — safe in synchronous Drop
154                match writer_lock.try_write() {
155                    Ok(mut writer) => writer.force_rollback(),
156                    Err(_) => error!(
157                        transaction_id = %self.id,
158                        "Could not acquire writer lock for auto-rollback"
159                    ),
160                }
161            }
162        }
163    }
164}
165
166/// Builder for parameterized mutations within a transaction.
167pub struct TransactionQueryBuilder<'a> {
168    tx: &'a Transaction<'a>,
169    cypher: String,
170    params: HashMap<String, Value>,
171}
172
173impl<'a> TransactionQueryBuilder<'a> {
174    /// Bind a parameter to the mutation.
175    pub fn param(mut self, name: &str, value: impl Into<Value>) -> Self {
176        self.params.insert(name.to_string(), value.into());
177        self
178    }
179
180    /// Execute the mutation and return affected row count.
181    pub async fn execute(self) -> Result<ExecuteResult> {
182        let before = self.tx.db.get_mutation_count().await;
183        let result = self
184            .tx
185            .db
186            .execute_internal(&self.cypher, self.params)
187            .await?;
188        let affected_rows = if result.is_empty() {
189            self.tx.db.get_mutation_count().await.saturating_sub(before)
190        } else {
191            result.len()
192        };
193        Ok(ExecuteResult { affected_rows })
194    }
195}
196
197impl Uni {
198    pub async fn begin(&self) -> Result<Transaction<'_>> {
199        Transaction::new(self).await
200    }
201
202    pub async fn transaction<'a, F, T>(&'a self, f: F) -> Result<T>
203    where
204        F: for<'b> FnOnce(&'b mut Transaction<'a>) -> BoxFuture<'b, Result<T>>,
205    {
206        let mut tx = self.begin().await?;
207
208        match f(&mut tx).await {
209            Ok(v) => match tx.commit().await {
210                Ok(_) => Ok(v),
211                Err(uni_common::UniError::TransactionAlreadyCompleted) => Ok(v),
212                Err(e) => Err(e),
213            },
214            Err(e) => {
215                // Ignore rollback error if it fails, but log it
216                if let Err(rollback_err) = tx.rollback().await {
217                    error!(
218                        "Transaction rollback failed during error recovery: {}",
219                        rollback_err
220                    );
221                }
222                Err(e)
223            }
224        }
225    }
226}