1use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use std::time::{Duration, Instant};
7use tracing::{info, debug};
8
9use crate::storage::{EventLogEntry, OperationRecord, StateRecord, Storage};
10use crate::types::*;
11
/// A transaction that has been opened but not yet committed or aborted.
///
/// It only buffers operations; nothing is written to storage until the transaction is
/// committed.
#[derive(Debug, Clone)]
struct Transaction {
    /// Identifier returned to the caller when the transaction was opened.
    txn_id: TxnId,
    /// When the transaction was opened; the timeout is measured from this instant.
    created_at: Instant,
    /// Maximum lifetime. Once `created_at.elapsed()` exceeds it, the transaction is rejected.
    timeout: Duration,
    /// Staged writes and deletes, in the order they were issued.
    operations: Vec<StagedOperation>,
}
20
/// A single mutation buffered inside a [`Transaction`] until it is committed.
#[derive(Debug, Clone)]
enum StagedOperation {
    /// Store `value` under `namespace`/`agent_id`/`key`.
    Write {
        namespace: Namespace,
        agent_id: AgentId,
        key: Key,
        value: serde_json::Value,
    },
    /// Delete `namespace`/`agent_id`/`key`. At commit time this becomes a record with
    /// `deleted == true` and no value.
    Delete {
        namespace: Namespace,
        agent_id: AgentId,
        key: Key,
    },
}
35
/// A request against the state machine, expressed as data.
///
/// Each variant carries the arguments of the `StateMachine` method with the matching name.
/// NOTE(review): this type is not consumed in this file; presumably it is decoded from the
/// wire and dispatched elsewhere — confirm against callers.
#[derive(Debug)]
pub enum Command {
    /// Open a transaction; `None` selects the default timeout.
    BeginTransaction {
        timeout_ms: Option<u64>,
    },
    /// Stage a write of `value` to `namespace`/`agent_id`/`key` in `txn_id`.
    Write {
        txn_id: TxnId,
        namespace: Namespace,
        agent_id: AgentId,
        key: Key,
        value: serde_json::Value,
    },
    /// Stage a delete of `namespace`/`agent_id`/`key` in `txn_id`.
    Delete {
        txn_id: TxnId,
        namespace: Namespace,
        agent_id: AgentId,
        key: Key,
    },
    /// Commit the staged operations of `txn_id`.
    Commit {
        txn_id: TxnId,
    },
    /// Discard `txn_id` and its staged operations.
    Abort {
        txn_id: TxnId,
    },
}
62
/// Transactional, versioned key/value state layered over a [`Storage`] backend.
///
/// Writes are staged in in-memory transactions and applied to storage only on commit.
/// Each commit stamps its records with a commit timestamp and a per-key version.
pub struct StateMachine {
    /// Backend holding state records, the event log and snapshots.
    storage: Arc<dyn Storage>,
    /// Open (not yet committed or aborted) transactions, keyed by transaction id.
    transactions: Arc<RwLock<HashMap<TxnId, Transaction>>>,
    /// Last version handed out for each record. In memory only; commit holds this lock's
    /// write guard while it writes records.
    version_counters: Arc<RwLock<HashMap<RecordId, Version>>>,
    /// Number of `maybe_snapshot` calls since the last snapshot was taken.
    commits_since_snapshot: Arc<RwLock<u64>>,
}
71
72impl StateMachine {
73 pub fn new(storage: Arc<dyn Storage>) -> Self {
74 Self {
75 storage,
76 transactions: Arc::new(RwLock::new(HashMap::new())),
77 version_counters: Arc::new(RwLock::new(HashMap::new())),
78 commits_since_snapshot: Arc::new(RwLock::new(0)),
79 }
80 }
81
82 pub fn begin_transaction(&self, timeout_ms: Option<u64>) -> Result<TxnId> {
84 let txn_id = uuid::Uuid::new_v4().to_string();
85 let timeout = Duration::from_millis(timeout_ms.unwrap_or(30000));
86
87 let txn = Transaction {
88 txn_id: txn_id.clone(),
89 created_at: Instant::now(),
90 timeout,
91 operations: Vec::new(),
92 };
93
94 let mut transactions = self.transactions.write().unwrap();
95 transactions.insert(txn_id.clone(), txn);
96
97 debug!("Transaction started: txn_id={}", txn_id);
98 Ok(txn_id)
99 }
100
101 pub fn write(&self, txn_id: &str, namespace: String, agent_id: String, key: String, value: serde_json::Value) -> Result<()> {
103 let mut transactions = self.transactions.write().unwrap();
104 let txn = transactions.get_mut(txn_id).ok_or_else(|| anyhow!("Transaction not found"))?;
105
106 if txn.created_at.elapsed() > txn.timeout {
108 transactions.remove(txn_id);
109 return Err(anyhow!("Transaction expired"));
110 }
111
112 txn.operations.push(StagedOperation::Write {
113 namespace,
114 agent_id,
115 key,
116 value,
117 });
118
119 Ok(())
120 }
121
122 pub fn delete(&self, txn_id: &str, namespace: String, agent_id: String, key: String) -> Result<()> {
124 let mut transactions = self.transactions.write().unwrap();
125 let txn = transactions.get_mut(txn_id).ok_or_else(|| anyhow!("Transaction not found"))?;
126
127 if txn.created_at.elapsed() > txn.timeout {
129 transactions.remove(txn_id);
130 return Err(anyhow!("Transaction expired"));
131 }
132
133 txn.operations.push(StagedOperation::Delete {
134 namespace,
135 agent_id,
136 key,
137 });
138
139 Ok(())
140 }
141
142 pub fn commit(&self, txn_id: &str) -> Result<CommitTs> {
144 use tracing::{info, debug};
145
146 debug!(txn_id = %txn_id, "Committing transaction");
147
148 let txn = {
150 let mut transactions = self.transactions.write().unwrap();
151 transactions.remove(txn_id).ok_or_else(|| anyhow!("Transaction not found"))?
152 };
153
154 if txn.created_at.elapsed() > txn.timeout {
156 debug!(txn_id = %txn_id, "Transaction expired");
157 return Err(anyhow!("Transaction expired"));
158 }
159
160 let commit_ts = self.storage.next_commit_ts()?;
162
163 let mut operation_records = Vec::new();
165 let mut version_counters = self.version_counters.write().unwrap();
166
167 for op in txn.operations {
168 match op {
169 StagedOperation::Write { namespace, agent_id, key, value } => {
170 let record_id = RecordId::new(namespace.clone(), agent_id.clone(), key.clone());
171
172 let version = version_counters.entry(record_id.clone()).or_insert(0);
174 *version += 1;
175 let current_version = *version;
176
177 let record = StateRecord {
179 namespace: namespace.clone(),
180 agent_id: agent_id.clone(),
181 key: key.clone(),
182 value: Some(value.clone()),
183 version: current_version,
184 commit_ts,
185 deleted: false,
186 };
187 self.storage.write_state(record)?;
188
189 operation_records.push(OperationRecord {
191 namespace,
192 agent_id,
193 key,
194 value: Some(value),
195 version: current_version,
196 });
197 }
198 StagedOperation::Delete { namespace, agent_id, key } => {
199 let record_id = RecordId::new(namespace.clone(), agent_id.clone(), key.clone());
200
201 let version = version_counters.entry(record_id.clone()).or_insert(0);
203 *version += 1;
204 let current_version = *version;
205
206 let record = StateRecord {
208 namespace: namespace.clone(),
209 agent_id: agent_id.clone(),
210 key: key.clone(),
211 value: None,
212 version: current_version,
213 commit_ts,
214 deleted: true,
215 };
216 self.storage.write_state(record)?;
217
218 operation_records.push(OperationRecord {
220 namespace,
221 agent_id,
222 key,
223 value: None,
224 version: current_version,
225 });
226 }
227 }
228 }
229
230 let event = EventLogEntry {
232 txn_id: txn.txn_id.clone(),
233 commit_ts,
234 operations: operation_records.clone(),
235 };
236 self.storage.append_event(event)?;
237
238 self.storage.flush()?;
240
241 info!(
243 txn_id = %txn_id,
244 commit_ts = commit_ts,
245 operations = operation_records.len(),
246 "Transaction committed"
247 );
248
249 Ok(commit_ts)
250 }
251
252 pub fn abort(&self, txn_id: &str) -> Result<()> {
254 use tracing::debug;
255
256 let mut transactions = self.transactions.write().unwrap();
257 if transactions.remove(txn_id).is_some() {
258 debug!(txn_id = %txn_id, "Transaction aborted");
259 }
260 Ok(())
261 }
262
263 pub fn get_state(&self, namespace: &str, agent_id: &str, key: &str) -> Result<Option<StateRecord>> {
265 let record_id = RecordId::new(namespace.to_string(), agent_id.to_string(), key.to_string());
266 self.storage.read_state(&record_id)
267 }
268
269 pub fn get_state_at_version(&self, namespace: &str, agent_id: &str, key: &str, version: Version) -> Result<Option<StateRecord>> {
271 let record_id = RecordId::new(namespace.to_string(), agent_id.to_string(), key.to_string());
272 self.storage.read_state_at_version(&record_id, version)
273 }
274
275 pub fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>> {
277 self.storage.list_keys(namespace, agent_id)
278 }
279
280 pub fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>> {
282 self.storage.scan_prefix(namespace, agent_id, prefix)
283 }
284
285 pub fn replay(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>> {
287 info!(
288 namespace = %namespace,
289 agent_id = %agent_id,
290 start_ts = ?start_ts,
291 end_ts = ?end_ts,
292 "Replay started"
293 );
294
295 let events = self.storage.replay_events(namespace, agent_id, start_ts, end_ts)?;
296 let event_count = events.len();
297
298 info!(
299 namespace = %namespace,
300 agent_id = %agent_id,
301 event_count = event_count,
302 "Replay completed"
303 );
304
305 Ok(events)
306 }
307
308 pub fn cleanup_expired_transactions(&self) {
310 let mut transactions = self.transactions.write().unwrap();
311 transactions.retain(|_, txn| txn.created_at.elapsed() <= txn.timeout);
312 }
313
314 pub fn create_snapshot(&self) -> Result<()> {
316 let snapshot = self.storage.create_snapshot()?;
317 self.storage.save_snapshot(&snapshot)?;
318
319 let mut counter = self.commits_since_snapshot.write().unwrap();
321 *counter = 0;
322
323 Ok(())
324 }
325
326 pub fn maybe_snapshot(&self, snapshot_interval: u64) -> Result<()> {
328 let mut counter = self.commits_since_snapshot.write().unwrap();
329 *counter += 1;
330
331 if *counter >= snapshot_interval {
332 drop(counter); self.create_snapshot()?;
334 }
335
336 Ok(())
337 }
338
339 pub fn recover_from_snapshot(&self, snapshot: &crate::storage::Snapshot) -> Result<()> {
341 let mut version_counters = self.version_counters.write().unwrap();
343 version_counters.clear();
344
345 for record in &snapshot.records {
346 let record_id = RecordId::new(
347 record.namespace.clone(),
348 record.agent_id.clone(),
349 record.key.clone(),
350 );
351 version_counters.insert(record_id, record.version);
352 }
353
354 Ok(())
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// A committed write is visible through `get_state` and receives a non-zero commit ts.
    #[test]
    fn test_transaction_lifecycle() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        let txn_id = sm.begin_transaction(None).unwrap();

        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "key1".to_string(), serde_json::json!({"value": 42})).unwrap();

        let commit_ts = sm.commit(&txn_id).unwrap();
        assert!(commit_ts > 0);

        let state = sm.get_state("default", "agent-1", "key1").unwrap();
        assert!(state.is_some());
        let record = state.unwrap();
        assert_eq!(record.value.unwrap()["value"], 42);
    }

    /// A committed delete leaves a record flagged `deleted` rather than removing the key.
    #[test]
    fn test_delete() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "key1".to_string(), serde_json::json!({"value": 42})).unwrap();
        sm.commit(&txn_id).unwrap();

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.delete(&txn_id, "default".to_string(), "agent-1".to_string(), "key1".to_string()).unwrap();
        sm.commit(&txn_id).unwrap();

        let state = sm.get_state("default", "agent-1", "key1").unwrap();
        assert!(state.is_some());
        assert!(state.unwrap().deleted);
    }

    /// Each commit to a key adds a new version. Older versions stay readable and
    /// `get_state` returns the newest.
    #[test]
    fn test_versioning() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "key1".to_string(), serde_json::json!({"value": 1})).unwrap();
        sm.commit(&txn_id).unwrap();

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "key1".to_string(), serde_json::json!({"value": 2})).unwrap();
        sm.commit(&txn_id).unwrap();

        let state_v1 = sm.get_state_at_version("default", "agent-1", "key1", 1).unwrap();
        assert_eq!(state_v1.unwrap().value.unwrap()["value"], 1);

        let state_v2 = sm.get_state_at_version("default", "agent-1", "key1", 2).unwrap();
        assert_eq!(state_v2.unwrap().value.unwrap()["value"], 2);

        let state_latest = sm.get_state("default", "agent-1", "key1").unwrap();
        assert_eq!(state_latest.unwrap().value.unwrap()["value"], 2);
    }

    /// A scalar value written and committed reads back unchanged.
    #[test]
    fn test_read_after_write() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "counter".to_string(), serde_json::json!(1)).unwrap();
        sm.commit(&txn_id).unwrap();

        let state = sm.get_state("default", "agent-1", "counter").unwrap();
        assert!(state.is_some());
        assert_eq!(state.unwrap().value.unwrap(), serde_json::json!(1));
    }

    /// Writes staged in an aborted transaction never reach storage.
    #[test]
    fn test_aborted_tx_has_no_effect() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        let txn_id = sm.begin_transaction(None).unwrap();

        sm.write(&txn_id, "default".to_string(), "agent-1".to_string(), "temp".to_string(), serde_json::json!(42)).unwrap();

        sm.abort(&txn_id).unwrap();

        let state = sm.get_state("default", "agent-1", "temp").unwrap();
        assert!(state.is_none());
    }

    /// Ten threads each commit their own key concurrently; every write must land intact.
    #[test]
    fn test_concurrent_commits_serialize() {
        use std::thread;

        let storage = Arc::new(InMemoryStorage::new());
        let sm = Arc::new(StateMachine::new(storage));

        let mut handles = vec![];

        for i in 0..10 {
            let sm_clone = sm.clone();
            let handle = thread::spawn(move || {
                let txn_id = sm_clone.begin_transaction(None).unwrap();
                sm_clone.write(
                    &txn_id,
                    "default".to_string(),
                    "agent-1".to_string(),
                    format!("key{}", i),
                    serde_json::json!(i),
                ).unwrap();
                sm_clone.commit(&txn_id).unwrap();
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        for i in 0..10 {
            let state = sm.get_state("default", "agent-1", &format!("key{}", i)).unwrap();
            assert!(state.is_some());
            assert_eq!(state.unwrap().value.unwrap(), serde_json::json!(i));
        }
    }

    /// Committing after the transaction's timeout has elapsed fails with "expired".
    #[test]
    fn test_transaction_timeout() {
        use std::time::Duration;
        use std::thread;

        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        // 100 ms timeout, then sleep past it.
        let txn_id = sm.begin_transaction(Some(100)).unwrap();
        thread::sleep(Duration::from_millis(150));

        let result = sm.commit(&txn_id);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("expired"));
    }

    /// Deleted keys are excluded from `list_keys`.
    #[test]
    fn test_list_keys_after_operations() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        for i in 1..=5 {
            let txn_id = sm.begin_transaction(None).unwrap();
            sm.write(
                &txn_id,
                "default".to_string(),
                "agent-1".to_string(),
                format!("key{}", i),
                serde_json::json!(i),
            ).unwrap();
            sm.commit(&txn_id).unwrap();
        }

        let txn_id = sm.begin_transaction(None).unwrap();
        sm.delete(&txn_id, "default".to_string(), "agent-1".to_string(), "key3".to_string()).unwrap();
        sm.commit(&txn_id).unwrap();

        let keys = sm.list_keys("default", "agent-1").unwrap();
        assert_eq!(keys.len(), 4);
        assert!(!keys.contains(&"key3".to_string()));
    }

    /// Replay yields one event per commit, in strictly increasing commit-ts order.
    #[test]
    fn test_replay_determinism() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage);

        for i in 1..=5 {
            let txn_id = sm.begin_transaction(None).unwrap();
            sm.write(
                &txn_id,
                "default".to_string(),
                "agent-1".to_string(),
                format!("key{}", i),
                serde_json::json!({"step": i}),
            ).unwrap();
            sm.commit(&txn_id).unwrap();
        }

        let events = sm.replay("default", "agent-1", None, None).unwrap();
        assert_eq!(events.len(), 5);

        for i in 1..events.len() {
            assert!(events[i].commit_ts > events[i - 1].commit_ts);
        }
    }

    /// An in-memory snapshot records metadata and one entry per committed key.
    #[test]
    fn test_create_snapshot() {
        let storage = Arc::new(InMemoryStorage::new());
        let sm = StateMachine::new(storage.clone());

        for i in 1..=3 {
            let txn_id = sm.begin_transaction(None).unwrap();
            sm.write(
                &txn_id,
                "default".to_string(),
                "agent-1".to_string(),
                format!("key{}", i),
                serde_json::json!({"value": i}),
            ).unwrap();
            sm.commit(&txn_id).unwrap();
        }

        let snapshot = storage.create_snapshot().unwrap();
        assert_eq!(snapshot.metadata.version, 1);
        assert_eq!(snapshot.metadata.record_count, 3);
        assert_eq!(snapshot.records.len(), 3);
    }

    /// A snapshot saved to RocksDB survives reopening the database, and state can be read
    /// after recovering from it.
    #[test]
    fn test_snapshot_persistence_with_rocksdb() {
        use tempfile::TempDir;
        use crate::storage::RocksStorage;

        let temp_dir = TempDir::new().unwrap();
        let config = crate::storage::StorageConfig {
            data_dir: temp_dir.path().to_path_buf(),
            fsync_on_commit: true,
            snapshot_interval: 10,
            max_log_size: 1024 * 1024,
        };

        // First session: write five keys and save a snapshot. The scope drops the storage
        // handle so the database can be reopened below.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage.clone());

            for i in 1..=5 {
                let txn_id = sm.begin_transaction(None).unwrap();
                sm.write(
                    &txn_id,
                    "default".to_string(),
                    "agent-1".to_string(),
                    format!("key{}", i),
                    serde_json::json!({"value": i}),
                ).unwrap();
                sm.commit(&txn_id).unwrap();
            }

            sm.create_snapshot().unwrap();
        }

        // Second session: load the snapshot, recover, and read the keys back.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let snapshot = storage.load_snapshot().unwrap();
            assert!(snapshot.is_some());

            let snapshot = snapshot.unwrap();
            assert_eq!(snapshot.metadata.version, crate::storage::SNAPSHOT_VERSION);
            assert_eq!(snapshot.metadata.record_count, 5);

            let sm = StateMachine::new(storage);
            sm.recover_from_snapshot(&snapshot).unwrap();

            for i in 1..=5 {
                let state = sm.get_state("default", "agent-1", &format!("key{}", i)).unwrap();
                assert!(state.is_some());
                assert_eq!(state.unwrap().value.unwrap()["value"], i);
            }
        }
    }

    /// Writes made after a snapshot remain readable after recovery, and replaying from the
    /// snapshot's timestamp returns them.
    #[test]
    fn test_log_replay_after_snapshot() {
        use tempfile::TempDir;
        use crate::storage::RocksStorage;

        let temp_dir = TempDir::new().unwrap();
        let config = crate::storage::StorageConfig {
            data_dir: temp_dir.path().to_path_buf(),
            fsync_on_commit: true,
            snapshot_interval: 3,
            max_log_size: 1024 * 1024,
        };

        let snapshot_ts;

        // Session 1: three writes, then a persisted snapshot.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage.clone());

            for i in 1..=3 {
                let txn_id = sm.begin_transaction(None).unwrap();
                sm.write(
                    &txn_id,
                    "default".to_string(),
                    "agent-1".to_string(),
                    format!("before_{}", i),
                    serde_json::json!(i),
                ).unwrap();
                sm.commit(&txn_id).unwrap();
            }

            sm.create_snapshot().unwrap();
            // Use the timestamp of the snapshot that was actually saved, instead of
            // building a second, unsaved snapshot just to read its timestamp.
            let saved = storage
                .load_snapshot()
                .unwrap()
                .expect("snapshot was saved just above");
            snapshot_ts = saved.metadata.snapshot_ts;
        }

        // Session 2: two more writes after the snapshot.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage.clone());

            for i in 1..=2 {
                let txn_id = sm.begin_transaction(None).unwrap();
                sm.write(
                    &txn_id,
                    "default".to_string(),
                    "agent-1".to_string(),
                    format!("after_{}", i),
                    serde_json::json!(i + 100),
                ).unwrap();
                sm.commit(&txn_id).unwrap();
            }
        }

        // Session 3: recover from the snapshot and check both generations of keys.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage.clone());

            let snapshot = storage.load_snapshot().unwrap();
            assert!(snapshot.is_some());
            sm.recover_from_snapshot(&snapshot.unwrap()).unwrap();

            for i in 1..=3 {
                let state = sm.get_state("default", "agent-1", &format!("before_{}", i)).unwrap();
                assert!(state.is_some());
            }

            for i in 1..=2 {
                let state = sm.get_state("default", "agent-1", &format!("after_{}", i)).unwrap();
                assert!(state.is_some());
                assert_eq!(state.unwrap().value.unwrap(), serde_json::json!(i + 100));
            }

            let events = sm.replay("default", "agent-1", Some(snapshot_ts), None).unwrap();
            assert!(events.len() >= 2);
        }
    }

    /// A committed write is readable after the database is closed and reopened.
    #[test]
    fn test_crash_recovery() {
        use tempfile::TempDir;
        use crate::storage::RocksStorage;

        let temp_dir = TempDir::new().unwrap();
        let config = crate::storage::StorageConfig {
            data_dir: temp_dir.path().to_path_buf(),
            fsync_on_commit: true,
            snapshot_interval: 10,
            max_log_size: 1024 * 1024,
        };

        // Session 1: commit one write, then drop everything (simulated crash).
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage);

            let txn_id = sm.begin_transaction(None).unwrap();
            sm.write(
                &txn_id,
                "default".to_string(),
                "agent-1".to_string(),
                "crash_test".to_string(),
                serde_json::json!({"status": "committed"}),
            ).unwrap();
            sm.commit(&txn_id).unwrap();
        }

        // Session 2: reopen and read the write back without any recovery step.
        {
            let storage = Arc::new(RocksStorage::new(config.clone()).unwrap());
            let sm = StateMachine::new(storage);

            let state = sm.get_state("default", "agent-1", "crash_test").unwrap();
            assert!(state.is_some());
            assert_eq!(state.unwrap().value.unwrap()["status"], "committed");
        }
    }
}