1use crate::operation::Operation;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8use ucm_core::{Error, Result};
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct TransactionId(pub String);
13
14impl TransactionId {
15 pub fn generate() -> Self {
16 use chrono::Utc;
17 #[cfg(not(target_arch = "wasm32"))]
18 let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0);
19 #[cfg(target_arch = "wasm32")]
20 let ts = 0; Self(format!("txn_{:x}", ts))
22 }
23
24 pub fn named(name: impl Into<String>) -> Self {
25 Self(name.into())
26 }
27}
28
29impl std::fmt::Display for TransactionId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "{}", self.0)
32 }
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum TransactionState {
38 Active,
39 Committed,
40 RolledBack,
41 TimedOut,
42}
43
44#[derive(Debug, Clone)]
46pub struct Transaction {
47 pub id: TransactionId,
49 pub name: Option<String>,
51 pub operations: Vec<Operation>,
53 pub savepoints: Vec<Savepoint>,
55 pub state: TransactionState,
57 #[cfg(not(target_arch = "wasm32"))]
59 pub started_at: Instant,
60 #[cfg(not(target_arch = "wasm32"))]
62 pub created_at: DateTime<Utc>,
63 pub timeout: Duration,
65}
66
67#[derive(Debug, Clone)]
69pub struct Savepoint {
70 pub name: String,
71 pub operation_index: usize,
72 #[cfg(not(target_arch = "wasm32"))]
73 pub created_at: DateTime<Utc>,
74}
75
76impl Transaction {
77 pub fn new(timeout: Duration) -> Self {
79 Self {
80 id: TransactionId::generate(),
81 name: None,
82 operations: Vec::new(),
83 savepoints: Vec::new(),
84 state: TransactionState::Active,
85 #[cfg(not(target_arch = "wasm32"))]
86 started_at: Instant::now(),
87 #[cfg(not(target_arch = "wasm32"))]
88 created_at: Utc::now(),
89 timeout,
90 }
91 }
92
93 pub fn named(name: impl Into<String>, timeout: Duration) -> Self {
95 let name = name.into();
96 Self {
97 id: TransactionId::named(&name),
98 name: Some(name),
99 operations: Vec::new(),
100 savepoints: Vec::new(),
101 state: TransactionState::Active,
102 #[cfg(not(target_arch = "wasm32"))]
103 started_at: Instant::now(),
104 #[cfg(not(target_arch = "wasm32"))]
105 created_at: Utc::now(),
106 timeout,
107 }
108 }
109
110 pub fn add_operation(&mut self, op: Operation) -> Result<()> {
112 if self.state != TransactionState::Active {
113 return Err(Error::Internal(format!(
114 "Cannot add operation to {:?} transaction",
115 self.state
116 )));
117 }
118 if self.is_timed_out() {
119 self.state = TransactionState::TimedOut;
120 return Err(Error::new(
121 ucm_core::ErrorCode::E301TransactionTimeout,
122 "Transaction timed out",
123 ));
124 }
125 self.operations.push(op);
126 Ok(())
127 }
128
129 pub fn savepoint(&mut self, name: impl Into<String>) {
131 self.savepoints.push(Savepoint {
132 name: name.into(),
133 operation_index: self.operations.len(),
134 #[cfg(not(target_arch = "wasm32"))]
135 created_at: Utc::now(),
136 });
137 }
138
139 pub fn is_timed_out(&self) -> bool {
141 #[cfg(not(target_arch = "wasm32"))]
142 return self.started_at.elapsed() > self.timeout;
143
144 #[cfg(target_arch = "wasm32")]
145 false
146 }
147
148 pub fn elapsed(&self) -> Duration {
150 #[cfg(not(target_arch = "wasm32"))]
151 return self.started_at.elapsed();
152
153 #[cfg(target_arch = "wasm32")]
154 Duration::from_secs(0)
155 }
156
157 pub fn operation_count(&self) -> usize {
159 self.operations.len()
160 }
161}
162
163#[derive(Debug, Default)]
165pub struct TransactionManager {
166 transactions: HashMap<TransactionId, Transaction>,
168 default_timeout: Duration,
170}
171
172impl TransactionManager {
173 pub fn new() -> Self {
174 Self {
175 transactions: HashMap::new(),
176 default_timeout: Duration::from_secs(30),
177 }
178 }
179
180 pub fn with_timeout(timeout: Duration) -> Self {
181 Self {
182 transactions: HashMap::new(),
183 default_timeout: timeout,
184 }
185 }
186
187 pub fn begin(&mut self) -> TransactionId {
189 let txn = Transaction::new(self.default_timeout);
190 let id = txn.id.clone();
191 self.transactions.insert(id.clone(), txn);
192 id
193 }
194
195 pub fn begin_named(&mut self, name: impl Into<String>) -> TransactionId {
197 let txn = Transaction::named(name, self.default_timeout);
198 let id = txn.id.clone();
199 self.transactions.insert(id.clone(), txn);
200 id
201 }
202
203 pub fn get(&self, id: &TransactionId) -> Option<&Transaction> {
205 self.transactions.get(id)
206 }
207
208 pub fn get_mut(&mut self, id: &TransactionId) -> Option<&mut Transaction> {
210 self.transactions.get_mut(id)
211 }
212
213 pub fn add_operation(&mut self, id: &TransactionId, op: Operation) -> Result<()> {
215 let txn = self.transactions.get_mut(id).ok_or_else(|| {
216 Error::new(ucm_core::ErrorCode::E303TransactionNotFound, id.to_string())
217 })?;
218 txn.add_operation(op)
219 }
220
221 pub fn commit(&mut self, id: &TransactionId) -> Result<Vec<Operation>> {
223 let txn = self.transactions.get_mut(id).ok_or_else(|| {
224 Error::new(ucm_core::ErrorCode::E303TransactionNotFound, id.to_string())
225 })?;
226
227 if txn.state != TransactionState::Active {
228 return Err(Error::Internal(format!(
229 "Cannot commit {:?} transaction",
230 txn.state
231 )));
232 }
233
234 if txn.is_timed_out() {
235 txn.state = TransactionState::TimedOut;
236 return Err(Error::new(
237 ucm_core::ErrorCode::E301TransactionTimeout,
238 "Transaction timed out",
239 ));
240 }
241
242 txn.state = TransactionState::Committed;
243 Ok(txn.operations.clone())
244 }
245
246 pub fn rollback(&mut self, id: &TransactionId) -> Result<()> {
248 let txn = self.transactions.get_mut(id).ok_or_else(|| {
249 Error::new(ucm_core::ErrorCode::E303TransactionNotFound, id.to_string())
250 })?;
251
252 if txn.state != TransactionState::Active {
253 return Err(Error::Internal(format!(
254 "Cannot rollback {:?} transaction",
255 txn.state
256 )));
257 }
258
259 txn.state = TransactionState::RolledBack;
260 Ok(())
261 }
262
263 pub fn cleanup(&mut self) {
265 self.transactions
266 .retain(|_, txn| txn.state == TransactionState::Active && !txn.is_timed_out());
267 }
268
269 pub fn active_count(&self) -> usize {
271 self.transactions
272 .values()
273 .filter(|t| t.state == TransactionState::Active)
274 .count()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation::PruneCondition;

    /// begin -> add one operation -> commit hands back exactly that operation.
    #[test]
    fn test_transaction_lifecycle() {
        let mut manager = TransactionManager::new();

        let txn_id = manager.begin();
        assert_eq!(manager.active_count(), 1);

        let prune = Operation::Prune {
            condition: Some(PruneCondition::Unreachable),
        };
        manager
            .add_operation(&txn_id, prune)
            .expect("a fresh transaction accepts operations");

        let committed = manager
            .commit(&txn_id)
            .expect("commit within the timeout succeeds");
        assert_eq!(committed.len(), 1);
    }

    /// A named transaction uses its name verbatim as the id.
    #[test]
    fn test_named_transaction() {
        let mut manager = TransactionManager::new();
        let txn_id = manager.begin_named("my-transaction");
        assert_eq!(txn_id.0, "my-transaction");
    }

    /// Rolling back leaves the transaction tracked, in the RolledBack state.
    #[test]
    fn test_rollback() {
        let mut manager = TransactionManager::new();

        let txn_id = manager.begin();
        manager
            .rollback(&txn_id)
            .expect("an active transaction can be rolled back");

        let state = manager
            .get(&txn_id)
            .expect("rolled-back transactions stay tracked")
            .state;
        assert_eq!(state, TransactionState::RolledBack);
    }

    /// Committing after the timeout has elapsed fails.
    #[test]
    fn test_timeout() {
        let mut manager = TransactionManager::with_timeout(Duration::from_millis(1));

        let txn_id = manager.begin();
        std::thread::sleep(Duration::from_millis(10));

        assert!(manager.commit(&txn_id).is_err());
    }
}