// rust_logic_graph/multi_db/transaction.rs
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::Mutex;
use tracing::{error, info, warn};

use crate::error::{ErrorContext, RustLogicGraphError};

/// Lifecycle state shared by a [`DistributedTransaction`] and each of its
/// [`TransactionParticipant`]s during two-phase commit.
///
/// The enum carries no data, so it also derives `Copy` (pass by value, no
/// `.clone()` needed) and `Hash` (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionState {
    /// Created or enlisted, but not yet prepared.
    Initiated,
    /// Prepare succeeded; waiting for commit or rollback.
    Prepared,
    /// Commit phase completed.
    Committed,
    /// Prepare was refused/failed, or the transaction was aborted.
    Aborted,
}

21#[derive(Debug, Clone)]
23pub struct TransactionParticipant {
24 pub id: String,
25 pub database: String,
26 pub state: TransactionState,
27}
28
29#[derive(Clone)]
60pub struct DistributedTransaction {
61 pub id: String,
62 pub participants: Vec<TransactionParticipant>,
63 pub state: TransactionState,
64 pub metadata: HashMap<String, String>,
65}
66
67impl DistributedTransaction {
68 pub fn new(id: impl Into<String>) -> Self {
70 let txn_id = id.into();
71 info!(
72 "đ Distributed Transaction: Creating transaction '{}'",
73 txn_id
74 );
75
76 Self {
77 id: txn_id,
78 participants: Vec::new(),
79 state: TransactionState::Initiated,
80 metadata: HashMap::new(),
81 }
82 }
83
84 pub fn add_participant(
86 &mut self,
87 database: impl Into<String>,
88 id: impl Into<String>,
89 ) -> &mut Self {
90 let participant = TransactionParticipant {
91 id: id.into(),
92 database: database.into(),
93 state: TransactionState::Initiated,
94 };
95
96 info!(
97 " â Adding participant: {} ({})",
98 participant.id, participant.database
99 );
100 self.participants.push(participant);
101 self
102 }
103
104 pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
106 self.metadata.insert(key.into(), value.into());
107 self
108 }
109
110 pub async fn prepare(&mut self) -> Result<bool, RustLogicGraphError> {
118 info!(
119 "đ Transaction '{}': PREPARE phase starting ({} participants)",
120 self.id,
121 self.participants.len()
122 );
123
124 if self.state != TransactionState::Initiated {
125 return Err(RustLogicGraphError::configuration_error(format!(
126 "Cannot prepare transaction in state {:?}",
127 self.state
128 )));
129 }
130
131 let mut all_prepared = true;
132
133 let participant_count = self.participants.len();
135 for i in 0..participant_count {
136 info!(
137 " đ Preparing participant: {} ({})",
138 self.participants[i].id, self.participants[i].database
139 );
140
141 let participant = &self.participants[i];
144 let prepared = self.simulate_prepare(participant).await?;
145
146 if prepared {
147 self.participants[i].state = TransactionState::Prepared;
148 info!(
149 " â
Participant {} prepared successfully",
150 self.participants[i].id
151 );
152 } else {
153 self.participants[i].state = TransactionState::Aborted;
154 warn!(
155 " â Participant {} failed to prepare",
156 self.participants[i].id
157 );
158 all_prepared = false;
159 break;
160 }
161 }
162
163 if all_prepared {
164 self.state = TransactionState::Prepared;
165 info!("â
Transaction '{}': All participants prepared", self.id);
166 } else {
167 self.state = TransactionState::Aborted;
168 warn!("â ī¸ Transaction '{}': Prepare phase failed", self.id);
169 }
170
171 Ok(all_prepared)
172 }
173
174 pub fn can_commit(&self) -> bool {
176 self.state == TransactionState::Prepared
177 && self
178 .participants
179 .iter()
180 .all(|p| p.state == TransactionState::Prepared)
181 }
182
183 pub async fn commit(&mut self) -> Result<(), RustLogicGraphError> {
185 info!("đ Transaction '{}': COMMIT phase starting", self.id);
186
187 if !self.can_commit() {
188 return Err(RustLogicGraphError::configuration_error(format!(
189 "Cannot commit transaction in state {:?}",
190 self.state
191 )));
192 }
193
194 let participant_count = self.participants.len();
196 for i in 0..participant_count {
197 info!(
198 " đž Committing participant: {} ({})",
199 self.participants[i].id, self.participants[i].database
200 );
201
202 let participant = &self.participants[i];
204 self.simulate_commit(participant).await?;
205
206 self.participants[i].state = TransactionState::Committed;
207 info!(" â
Participant {} committed", self.participants[i].id);
208 }
209
210 self.state = TransactionState::Committed;
211 info!("â
Transaction '{}': Successfully committed", self.id);
212
213 Ok(())
214 }
215
216 pub async fn abort(&mut self) -> Result<(), RustLogicGraphError> {
218 warn!("đ Transaction '{}': ABORT phase starting", self.id);
219
220 let participant_count = self.participants.len();
222 for i in 0..participant_count {
223 if self.participants[i].state == TransactionState::Prepared {
224 warn!(
225 " âŠī¸ Rolling back participant: {} ({})",
226 self.participants[i].id, self.participants[i].database
227 );
228
229 let participant = &self.participants[i];
231 self.simulate_rollback(participant).await?;
232
233 self.participants[i].state = TransactionState::Aborted;
234 warn!(" â
Participant {} rolled back", self.participants[i].id);
235 }
236 }
237
238 self.state = TransactionState::Aborted;
239 warn!("â ī¸ Transaction '{}': Aborted and rolled back", self.id);
240
241 Ok(())
242 }
243
244 async fn simulate_prepare(
247 &self,
248 _participant: &TransactionParticipant,
249 ) -> Result<bool, RustLogicGraphError> {
250 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
252 Ok(true) }
254
255 async fn simulate_commit(
256 &self,
257 _participant: &TransactionParticipant,
258 ) -> Result<(), RustLogicGraphError> {
259 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
260 Ok(())
261 }
262
263 async fn simulate_rollback(
264 &self,
265 _participant: &TransactionParticipant,
266 ) -> Result<(), RustLogicGraphError> {
267 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
268 Ok(())
269 }
270}
271
272pub struct TransactionCoordinator {
274 transactions: Arc<Mutex<HashMap<String, DistributedTransaction>>>,
275}
276
277impl TransactionCoordinator {
278 pub fn new() -> Self {
280 Self {
281 transactions: Arc::new(Mutex::new(HashMap::new())),
282 }
283 }
284
285 pub async fn begin(&self, txn_id: impl Into<String>) -> Result<String, RustLogicGraphError> {
287 let id = txn_id.into();
288 let txn = DistributedTransaction::new(id.clone());
289
290 let mut txns = self.transactions.lock().await;
291 if txns.contains_key(&id) {
292 return Err(RustLogicGraphError::configuration_error(format!(
293 "Transaction '{}' already exists",
294 id
295 )));
296 }
297
298 txns.insert(id.clone(), txn);
299 Ok(id)
300 }
301
302 pub async fn get(&self, txn_id: &str) -> Result<DistributedTransaction, RustLogicGraphError> {
304 let txns = self.transactions.lock().await;
305 txns.get(txn_id).cloned().ok_or_else(|| {
306 RustLogicGraphError::configuration_error(format!("Transaction '{}' not found", txn_id))
307 })
308 }
309
310 pub async fn update(&self, txn: DistributedTransaction) -> Result<(), RustLogicGraphError> {
312 let mut txns = self.transactions.lock().await;
313 txns.insert(txn.id.clone(), txn);
314 Ok(())
315 }
316
317 pub async fn remove(&self, txn_id: &str) -> Result<(), RustLogicGraphError> {
319 let mut txns = self.transactions.lock().await;
320 txns.remove(txn_id);
321 Ok(())
322 }
323
324 pub async fn active_transactions(&self) -> Vec<String> {
326 let txns = self.transactions.lock().await;
327 txns.keys().cloned().collect()
328 }
329}
330
331impl Default for TransactionCoordinator {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_transaction_lifecycle() {
        let mut txn = DistributedTransaction::new("test_txn");

        txn.add_participant("db1", "op1");
        txn.add_participant("db2", "op2");

        let prepared = txn.prepare().await.unwrap();
        assert!(prepared);
        assert_eq!(txn.state, TransactionState::Prepared);

        txn.commit().await.unwrap();
        assert_eq!(txn.state, TransactionState::Committed);
    }

    #[tokio::test]
    async fn test_transaction_abort() {
        let mut txn = DistributedTransaction::new("test_txn");

        txn.add_participant("db1", "op1");

        txn.prepare().await.unwrap();

        txn.abort().await.unwrap();
        assert_eq!(txn.state, TransactionState::Aborted);
        // The prepared participant must have been rolled back.
        assert!(txn
            .participants
            .iter()
            .all(|p| p.state == TransactionState::Aborted));
    }

    #[tokio::test]
    async fn test_prepare_rejects_non_initiated() {
        let mut txn = DistributedTransaction::new("test_txn");
        txn.add_participant("db1", "op1");

        assert!(txn.prepare().await.unwrap());
        // A second prepare must be refused once the transaction has left `Initiated`.
        assert!(txn.prepare().await.is_err());
    }

    #[tokio::test]
    async fn test_commit_requires_prepare() {
        let mut txn = DistributedTransaction::new("test_txn");
        txn.add_participant("db1", "op1");

        assert!(!txn.can_commit());
        assert!(txn.commit().await.is_err());
        assert_eq!(txn.state, TransactionState::Initiated);
    }

    #[tokio::test]
    async fn test_coordinator() {
        let coordinator = TransactionCoordinator::new();

        let txn_id = coordinator.begin("txn1").await.unwrap();
        assert_eq!(txn_id, "txn1");

        let txn = coordinator.get(&txn_id).await.unwrap();
        assert_eq!(txn.id, "txn1");

        let active = coordinator.active_transactions().await;
        assert_eq!(active.len(), 1);
    }

    #[tokio::test]
    async fn test_coordinator_duplicate_update_and_remove() {
        let coordinator = TransactionCoordinator::new();

        coordinator.begin("txn1").await.unwrap();
        assert!(coordinator.begin("txn1").await.is_err());

        // `get` hands out a snapshot; changes persist only via `update`.
        let mut txn = coordinator.get("txn1").await.unwrap();
        txn.add_metadata("owner", "tests");
        coordinator.update(txn).await.unwrap();
        let stored = coordinator.get("txn1").await.unwrap();
        assert_eq!(stored.metadata.get("owner").map(String::as_str), Some("tests"));

        coordinator.remove("txn1").await.unwrap();
        assert!(coordinator.get("txn1").await.is_err());
        assert!(coordinator.active_transactions().await.is_empty());
    }
}