Skip to main content

statehouse_core/
storage.rs

// Storage trait and implementations: an in-memory backend (for tests) and a RocksDB-backed persistent backend.
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
8use crate::types::*;
9
/// Format version stamped into every snapshot's metadata.
///
/// `RocksStorage::load_snapshot` rejects any snapshot whose version differs from this
/// value, so it must be bumped whenever the serialized snapshot layout changes
/// incompatibly.
pub const SNAPSHOT_VERSION: u32 = 1;
12
13/// Storage configuration
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct StorageConfig {
16    /// Data directory for persistent storage
17    pub data_dir: PathBuf,
18    /// Enable fsync on commit (slower but safer)
19    pub fsync_on_commit: bool,
20    /// Snapshot interval (number of commits)
21    pub snapshot_interval: u64,
22    /// Max log size before compaction (bytes)
23    pub max_log_size: u64,
24}
25
26impl Default for StorageConfig {
27    fn default() -> Self {
28        Self {
29            data_dir: PathBuf::from("./data"),
30            fsync_on_commit: true,
31            snapshot_interval: 1000,
32            max_log_size: 100 * 1024 * 1024, // 100MB
33        }
34    }
35}
36
37/// State record stored in the database
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct StateRecord {
40    pub namespace: Namespace,
41    pub agent_id: AgentId,
42    pub key: Key,
43    pub value: Option<serde_json::Value>,
44    pub version: Version,
45    pub commit_ts: CommitTs,
46    pub deleted: bool,
47}
48
49/// Event log entry
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct EventLogEntry {
52    pub txn_id: TxnId,
53    pub commit_ts: CommitTs,
54    pub operations: Vec<OperationRecord>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct OperationRecord {
59    pub namespace: Namespace,
60    pub agent_id: AgentId,
61    pub key: Key,
62    pub value: Option<serde_json::Value>,
63    pub version: Version,
64}
65
66/// Snapshot metadata
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct SnapshotMetadata {
69    /// Snapshot format version
70    pub version: u32,
71    /// Commit timestamp when snapshot was created
72    pub snapshot_ts: CommitTs,
73    /// Number of state records in snapshot
74    pub record_count: usize,
75    /// Timestamp when snapshot was created (system time)
76    pub created_at: u64,
77}
78
79/// Complete snapshot of system state
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Snapshot {
82    pub metadata: SnapshotMetadata,
83    pub records: Vec<StateRecord>,
84}
85
86/// Storage abstraction for Statehouse
87pub trait Storage: Send + Sync {
88    /// Health check
89    fn health_check(&self) -> Result<()>;
90
91    /// Write a state record
92    fn write_state(&self, record: StateRecord) -> Result<()>;
93
94    /// Read a state record
95    fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>>;
96
97    /// Read state at specific version
98    fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>>;
99
100    /// List all keys for an agent
101    fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>>;
102
103    /// Scan keys with prefix
104    fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>>;
105
106    /// Append event to log
107    fn append_event(&self, event: EventLogEntry) -> Result<()>;
108
109    /// Replay events for an agent
110    fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>>;
111
112    /// Get next commit timestamp
113    fn next_commit_ts(&self) -> Result<CommitTs>;
114
115    /// Flush writes to disk
116    fn flush(&self) -> Result<()>;
117
118    /// Create a snapshot of current state
119    fn create_snapshot(&self) -> Result<Snapshot>;
120
121    /// Save snapshot to disk
122    fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()>;
123
124    /// Load latest snapshot from disk
125    fn load_snapshot(&self) -> Result<Option<Snapshot>>;
126
127    /// Get all state records (for snapshotting)
128    fn get_all_state(&self) -> Result<Vec<StateRecord>>;
129}
130
// ============================================================================
// In-Memory Storage (for tests)
// ============================================================================
134
135use std::sync::{Arc, RwLock};
136
137pub struct InMemoryStorage {
138    state: Arc<RwLock<HashMap<RecordId, Vec<StateRecord>>>>,
139    events: Arc<RwLock<Vec<EventLogEntry>>>,
140    commit_ts_counter: Arc<RwLock<CommitTs>>,
141}
142
143impl InMemoryStorage {
144    pub fn new() -> Self {
145        Self {
146            state: Arc::new(RwLock::new(HashMap::new())),
147            events: Arc::new(RwLock::new(Vec::new())),
148            commit_ts_counter: Arc::new(RwLock::new(0)),
149        }
150    }
151}
152
153impl Default for InMemoryStorage {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl Storage for InMemoryStorage {
160    fn health_check(&self) -> Result<()> {
161        Ok(())
162    }
163
164    fn write_state(&self, record: StateRecord) -> Result<()> {
165        let mut state = self.state.write().unwrap();
166        let record_id = RecordId::new(
167            record.namespace.clone(),
168            record.agent_id.clone(),
169            record.key.clone(),
170        );
171        state.entry(record_id).or_insert_with(Vec::new).push(record);
172        Ok(())
173    }
174
175    fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>> {
176        let state = self.state.read().unwrap();
177        Ok(state.get(record_id).and_then(|versions| versions.last().cloned()))
178    }
179
180    fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>> {
181        let state = self.state.read().unwrap();
182        Ok(state.get(record_id).and_then(|versions| {
183            versions.iter().find(|r| r.version == version).cloned()
184        }))
185    }
186
187    fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>> {
188        let state = self.state.read().unwrap();
189        let keys: Vec<String> = state
190            .iter()
191            .filter(|(id, versions)| {
192                id.namespace == namespace
193                    && id.agent_id == agent_id
194                    && versions.last().map(|r| !r.deleted).unwrap_or(false)
195            })
196            .map(|(id, _)| id.key.clone())
197            .collect();
198        Ok(keys)
199    }
200
201    fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>> {
202        let state = self.state.read().unwrap();
203        let records: Vec<StateRecord> = state
204            .iter()
205            .filter(|(id, _)| {
206                id.namespace == namespace
207                    && id.agent_id == agent_id
208                    && id.key.starts_with(prefix)
209            })
210            .filter_map(|(_, versions)| versions.last().cloned())
211            .filter(|r| !r.deleted)
212            .collect();
213        Ok(records)
214    }
215
216    fn append_event(&self, event: EventLogEntry) -> Result<()> {
217        let mut events = self.events.write().unwrap();
218        events.push(event);
219        Ok(())
220    }
221
222    fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>> {
223        let events = self.events.read().unwrap();
224        let filtered: Vec<EventLogEntry> = events
225            .iter()
226            .filter(|e| {
227                e.operations.iter().any(|op| {
228                    op.namespace == namespace && op.agent_id == agent_id
229                })
230            })
231            .filter(|e| {
232                if let Some(start) = start_ts {
233                    e.commit_ts >= start
234                } else {
235                    true
236                }
237            })
238            .filter(|e| {
239                if let Some(end) = end_ts {
240                    e.commit_ts <= end
241                } else {
242                    true
243                }
244            })
245            .cloned()
246            .collect();
247        Ok(filtered)
248    }
249
250    fn next_commit_ts(&self) -> Result<CommitTs> {
251        let mut counter = self.commit_ts_counter.write().unwrap();
252        *counter += 1;
253        Ok(*counter)
254    }
255
256    fn flush(&self) -> Result<()> {
257        Ok(())
258    }
259
260    fn create_snapshot(&self) -> Result<Snapshot> {
261        let state = self.state.read().unwrap();
262        let commit_ts_counter = self.commit_ts_counter.read().unwrap();
263        
264        // Collect all latest state records
265        let mut records = Vec::new();
266        for versions in state.values() {
267            if let Some(record) = versions.last() {
268                records.push(record.clone());
269            }
270        }
271
272        let metadata = SnapshotMetadata {
273            version: SNAPSHOT_VERSION,
274            snapshot_ts: *commit_ts_counter,
275            record_count: records.len(),
276            created_at: std::time::SystemTime::now()
277                .duration_since(std::time::UNIX_EPOCH)
278                .unwrap()
279                .as_secs(),
280        };
281
282        Ok(Snapshot { metadata, records })
283    }
284
285    fn save_snapshot(&self, _snapshot: &Snapshot) -> Result<()> {
286        // In-memory storage doesn't persist snapshots
287        Ok(())
288    }
289
290    fn load_snapshot(&self) -> Result<Option<Snapshot>> {
291        // In-memory storage doesn't persist snapshots
292        Ok(None)
293    }
294
295    fn get_all_state(&self) -> Result<Vec<StateRecord>> {
296        let state = self.state.read().unwrap();
297        let mut records = Vec::new();
298        for versions in state.values() {
299            if let Some(record) = versions.last() {
300                records.push(record.clone());
301            }
302        }
303        Ok(records)
304    }
305}
306
// ============================================================================
// RocksDB Storage
// ============================================================================
310
311use rocksdb::{Options, DB};
312
313pub struct RocksStorage {
314    db: Arc<DB>,
315    config: StorageConfig,
316    commit_ts_counter: Arc<RwLock<CommitTs>>,
317}
318
319impl RocksStorage {
320    pub fn new(config: StorageConfig) -> Result<Self> {
321        std::fs::create_dir_all(&config.data_dir)?;
322
323        let mut opts = Options::default();
324        opts.create_if_missing(true);
325        opts.create_missing_column_families(true);
326
327        let db_path = config.data_dir.join("rocksdb");
328        let db = DB::open(&opts, db_path)?;
329
330        // Load current commit timestamp
331        let commit_ts = if let Some(value) = db.get(b"__commit_ts__")? {
332            u64::from_be_bytes(value.try_into().unwrap_or([0; 8]))
333        } else {
334            0
335        };
336
337        Ok(Self {
338            db: Arc::new(db),
339            config,
340            commit_ts_counter: Arc::new(RwLock::new(commit_ts)),
341        })
342    }
343
344    /// Restore state from snapshot
345    pub fn restore_from_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
346        // Write all records from snapshot
347        for record in &snapshot.records {
348            self.write_state(record.clone())?;
349        }
350
351        // Update commit timestamp counter
352        let mut counter = self.commit_ts_counter.write().unwrap();
353        *counter = snapshot.metadata.snapshot_ts;
354        self.db.put(b"__commit_ts__", &counter.to_be_bytes())?;
355
356        self.flush()?;
357        Ok(())
358    }
359
360    /// Get path for snapshot file
361    fn snapshot_path(&self) -> PathBuf {
362        self.config.data_dir.join("snapshot.json")
363    }
364
365    fn state_key(record_id: &RecordId) -> Vec<u8> {
366        format!("state:{}:{}:{}", record_id.namespace, record_id.agent_id, record_id.key).into_bytes()
367    }
368
369    fn version_key(record_id: &RecordId, version: Version) -> Vec<u8> {
370        format!("version:{}:{}:{}:{:020}", record_id.namespace, record_id.agent_id, record_id.key, version).into_bytes()
371    }
372
373    fn event_key(commit_ts: CommitTs) -> Vec<u8> {
374        format!("event:{:020}", commit_ts).into_bytes()
375    }
376}
377
378impl Storage for RocksStorage {
379    fn health_check(&self) -> Result<()> {
380        // Try a simple read
381        self.db.get(b"__health__")?;
382        Ok(())
383    }
384
385    fn write_state(&self, record: StateRecord) -> Result<()> {
386        let record_id = RecordId::new(
387            record.namespace.clone(),
388            record.agent_id.clone(),
389            record.key.clone(),
390        );
391
392        // Write latest state
393        let state_key = Self::state_key(&record_id);
394        let state_value = serde_json::to_vec(&record)?;
395        self.db.put(&state_key, &state_value)?;
396
397        // Write versioned state
398        let version_key = Self::version_key(&record_id, record.version);
399        self.db.put(&version_key, &state_value)?;
400
401        if self.config.fsync_on_commit {
402            self.db.flush()?;
403        }
404
405        Ok(())
406    }
407
408    fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>> {
409        let key = Self::state_key(record_id);
410        if let Some(value) = self.db.get(&key)? {
411            let record: StateRecord = serde_json::from_slice(&value)?;
412            Ok(Some(record))
413        } else {
414            Ok(None)
415        }
416    }
417
418    fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>> {
419        let key = Self::version_key(record_id, version);
420        if let Some(value) = self.db.get(&key)? {
421            let record: StateRecord = serde_json::from_slice(&value)?;
422            Ok(Some(record))
423        } else {
424            Ok(None)
425        }
426    }
427
428    fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>> {
429        let prefix = format!("state:{}:{}:", namespace, agent_id);
430        let mut keys = Vec::new();
431
432        let iter = self.db.prefix_iterator(prefix.as_bytes());
433        for item in iter {
434            let (key, value) = item?;
435            let key_str = String::from_utf8_lossy(&key);
436            if !key_str.starts_with(&prefix) {
437                break;
438            }
439
440            let record: StateRecord = serde_json::from_slice(&value)?;
441            if !record.deleted {
442                keys.push(record.key);
443            }
444        }
445
446        Ok(keys)
447    }
448
449    fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>> {
450        let state_prefix = format!("state:{}:{}:{}", namespace, agent_id, prefix);
451        let mut records = Vec::new();
452
453        let iter = self.db.prefix_iterator(state_prefix.as_bytes());
454        for item in iter {
455            let (key, value) = item?;
456            let key_str = String::from_utf8_lossy(&key);
457            if !key_str.starts_with(&state_prefix) {
458                break;
459            }
460
461            let record: StateRecord = serde_json::from_slice(&value)?;
462            if !record.deleted {
463                records.push(record);
464            }
465        }
466
467        Ok(records)
468    }
469
470    fn append_event(&self, event: EventLogEntry) -> Result<()> {
471        let key = Self::event_key(event.commit_ts);
472        let value = serde_json::to_vec(&event)?;
473        self.db.put(&key, &value)?;
474
475        if self.config.fsync_on_commit {
476            self.db.flush()?;
477        }
478
479        Ok(())
480    }
481
482    fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>> {
483        let start_key = if let Some(ts) = start_ts {
484            Self::event_key(ts)
485        } else {
486            b"event:".to_vec()
487        };
488
489        let mut events = Vec::new();
490        let iter = self.db.prefix_iterator(&start_key);
491
492        for item in iter {
493            let (key, value) = item?;
494            let key_str = String::from_utf8_lossy(&key);
495            if !key_str.starts_with("event:") {
496                break;
497            }
498
499            let event: EventLogEntry = serde_json::from_slice(&value)?;
500
501            // Check if event is relevant to this agent
502            let relevant = event.operations.iter().any(|op| {
503                op.namespace == namespace && op.agent_id == agent_id
504            });
505
506            if relevant {
507                if let Some(end) = end_ts {
508                    if event.commit_ts > end {
509                        break;
510                    }
511                }
512                events.push(event);
513            }
514        }
515
516        Ok(events)
517    }
518
519    fn next_commit_ts(&self) -> Result<CommitTs> {
520        let mut counter = self.commit_ts_counter.write().unwrap();
521        *counter += 1;
522        let ts = *counter;
523
524        // Persist commit timestamp
525        self.db.put(b"__commit_ts__", &ts.to_be_bytes())?;
526
527        Ok(ts)
528    }
529
530    fn flush(&self) -> Result<()> {
531        self.db.flush()?;
532        Ok(())
533    }
534
535    fn create_snapshot(&self) -> Result<Snapshot> {
536        let commit_ts_counter = self.commit_ts_counter.read().unwrap();
537        let records = self.get_all_state()?;
538
539        let metadata = SnapshotMetadata {
540            version: SNAPSHOT_VERSION,
541            snapshot_ts: *commit_ts_counter,
542            record_count: records.len(),
543            created_at: std::time::SystemTime::now()
544                .duration_since(std::time::UNIX_EPOCH)
545                .unwrap()
546                .as_secs(),
547        };
548
549        Ok(Snapshot { metadata, records })
550    }
551
552    fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
553        let path = self.snapshot_path();
554        let json = serde_json::to_string_pretty(snapshot)?;
555        std::fs::write(path, json)?;
556        Ok(())
557    }
558
559    fn load_snapshot(&self) -> Result<Option<Snapshot>> {
560        let path = self.snapshot_path();
561        
562        if !path.exists() {
563            return Ok(None);
564        }
565
566        let json = std::fs::read_to_string(path)?;
567        let snapshot: Snapshot = serde_json::from_str(&json)?;
568
569        // Verify snapshot version
570        if snapshot.metadata.version != SNAPSHOT_VERSION {
571            return Err(anyhow::anyhow!(
572                "Snapshot version mismatch: expected {}, got {}",
573                SNAPSHOT_VERSION,
574                snapshot.metadata.version
575            ));
576        }
577
578        Ok(Some(snapshot))
579    }
580
581    fn get_all_state(&self) -> Result<Vec<StateRecord>> {
582        let mut records = Vec::new();
583        let iter = self.db.prefix_iterator(b"state:");
584
585        for item in iter {
586            let (key, value) = item?;
587            let key_str = String::from_utf8_lossy(&key);
588            
589            if !key_str.starts_with("state:") {
590                break;
591            }
592
593            let record: StateRecord = serde_json::from_slice(&value)?;
594            records.push(record);
595        }
596
597        Ok(records)
598    }
599}