1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::PathBuf;
7
8use crate::types::*;
9
/// On-disk snapshot format version.
///
/// Written into `SnapshotMetadata::version` by `create_snapshot`; `RocksStorage::load_snapshot`
/// returns an error for any snapshot whose stored version differs from this value.
pub const SNAPSHOT_VERSION: u32 = 1;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct StorageConfig {
16 pub data_dir: PathBuf,
18 pub fsync_on_commit: bool,
20 pub snapshot_interval: u64,
22 pub max_log_size: u64,
24}
25
26impl Default for StorageConfig {
27 fn default() -> Self {
28 Self {
29 data_dir: PathBuf::from("./data"),
30 fsync_on_commit: true,
31 snapshot_interval: 1000,
32 max_log_size: 100 * 1024 * 1024, }
34 }
35}
36
/// One stored version of a single `(namespace, agent_id, key)` entry.
///
/// Backends treat the most recently written record for a key as its current state
/// and also keep each record addressable by its `version`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateRecord {
    pub namespace: Namespace,
    pub agent_id: AgentId,
    pub key: Key,
    /// JSON payload. NOTE(review): presumably `None` for deletions — confirm with writers.
    pub value: Option<serde_json::Value>,
    /// Version number; used as the lookup key by `read_state_at_version`.
    pub version: Version,
    /// Commit timestamp associated with this record.
    pub commit_ts: CommitTs,
    /// Tombstone flag: `list_keys` and `scan_prefix` skip records where this is `true`,
    /// while `read_state` and `get_all_state` still return them.
    pub deleted: bool,
}
48
/// One transaction recorded in the event log.
///
/// `RocksStorage` keys entries solely by `commit_ts`, so appending two entries with the
/// same timestamp leaves only the later one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventLogEntry {
    pub txn_id: TxnId,
    /// Commit timestamp; determines the entry's position when replaying.
    pub commit_ts: CommitTs,
    /// Writes performed by the transaction; `replay_events` matches on these.
    pub operations: Vec<OperationRecord>,
}
56
/// A single key write inside an [`EventLogEntry`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationRecord {
    pub namespace: Namespace,
    pub agent_id: AgentId,
    pub key: Key,
    /// Written JSON payload. NOTE(review): presumably `None` means delete — confirm.
    pub value: Option<serde_json::Value>,
    pub version: Version,
}
65
/// Header describing a [`Snapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Format version; set to `SNAPSHOT_VERSION` when the snapshot is created.
    pub version: u32,
    /// Value of the backend's commit-timestamp counter when the snapshot was taken.
    pub snapshot_ts: CommitTs,
    /// Number of entries in `Snapshot::records`.
    pub record_count: usize,
    /// Creation time in whole seconds since the UNIX epoch.
    pub created_at: u64,
}
78
/// Point-in-time copy of the current state of every key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
    pub metadata: SnapshotMetadata,
    /// Latest record of each key, tombstones (`deleted == true`) included.
    pub records: Vec<StateRecord>,
}
85
/// Backend interface for versioned per-agent state plus a transaction event log.
///
/// Implementors must be `Send + Sync` so a single instance can be shared across threads.
pub trait Storage: Send + Sync {
    /// Returns `Ok(())` when the backend can serve requests.
    fn health_check(&self) -> Result<()>;

    /// Records `record` as the newest version of its `(namespace, agent_id, key)`.
    fn write_state(&self, record: StateRecord) -> Result<()>;

    /// Returns the newest record for `record_id`, or `None` if it was never written.
    /// The bundled backends return tombstoned records here rather than hiding them.
    fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>>;

    /// Returns the stored record for `record_id` whose `version` equals `version`.
    fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>>;

    /// Lists keys under `(namespace, agent_id)` whose newest record is not deleted.
    /// NOTE(review): result ordering is backend-defined.
    fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>>;

    /// Returns the newest non-deleted records under `(namespace, agent_id)` whose key
    /// starts with `prefix`.
    fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>>;

    /// Adds a committed transaction to the event log.
    fn append_event(&self, event: EventLogEntry) -> Result<()>;

    /// Returns events with at least one operation on `(namespace, agent_id)` and
    /// `start_ts <= commit_ts <= end_ts`; a `None` bound is unbounded on that side.
    fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>>;

    /// Increments the commit-timestamp counter and returns the new value.
    fn next_commit_ts(&self) -> Result<CommitTs>;

    /// Pushes buffered writes to durable storage (a no-op for the in-memory backend).
    fn flush(&self) -> Result<()>;

    /// Captures the newest record of every key together with the current counter.
    fn create_snapshot(&self) -> Result<Snapshot>;

    /// Persists `snapshot` (the in-memory backend discards it).
    fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()>;

    /// Loads the previously saved snapshot, or `Ok(None)` if none is available.
    fn load_snapshot(&self) -> Result<Option<Snapshot>>;

    /// Returns the newest record of every stored key, tombstones included.
    fn get_all_state(&self) -> Result<Vec<StateRecord>>;
}
130
131use std::sync::{Arc, RwLock};
136
137pub struct InMemoryStorage {
138 state: Arc<RwLock<HashMap<RecordId, Vec<StateRecord>>>>,
139 events: Arc<RwLock<Vec<EventLogEntry>>>,
140 commit_ts_counter: Arc<RwLock<CommitTs>>,
141}
142
143impl InMemoryStorage {
144 pub fn new() -> Self {
145 Self {
146 state: Arc::new(RwLock::new(HashMap::new())),
147 events: Arc::new(RwLock::new(Vec::new())),
148 commit_ts_counter: Arc::new(RwLock::new(0)),
149 }
150 }
151}
152
153impl Default for InMemoryStorage {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl Storage for InMemoryStorage {
160 fn health_check(&self) -> Result<()> {
161 Ok(())
162 }
163
164 fn write_state(&self, record: StateRecord) -> Result<()> {
165 let mut state = self.state.write().unwrap();
166 let record_id = RecordId::new(
167 record.namespace.clone(),
168 record.agent_id.clone(),
169 record.key.clone(),
170 );
171 state.entry(record_id).or_insert_with(Vec::new).push(record);
172 Ok(())
173 }
174
175 fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>> {
176 let state = self.state.read().unwrap();
177 Ok(state.get(record_id).and_then(|versions| versions.last().cloned()))
178 }
179
180 fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>> {
181 let state = self.state.read().unwrap();
182 Ok(state.get(record_id).and_then(|versions| {
183 versions.iter().find(|r| r.version == version).cloned()
184 }))
185 }
186
187 fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>> {
188 let state = self.state.read().unwrap();
189 let keys: Vec<String> = state
190 .iter()
191 .filter(|(id, versions)| {
192 id.namespace == namespace
193 && id.agent_id == agent_id
194 && versions.last().map(|r| !r.deleted).unwrap_or(false)
195 })
196 .map(|(id, _)| id.key.clone())
197 .collect();
198 Ok(keys)
199 }
200
201 fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>> {
202 let state = self.state.read().unwrap();
203 let records: Vec<StateRecord> = state
204 .iter()
205 .filter(|(id, _)| {
206 id.namespace == namespace
207 && id.agent_id == agent_id
208 && id.key.starts_with(prefix)
209 })
210 .filter_map(|(_, versions)| versions.last().cloned())
211 .filter(|r| !r.deleted)
212 .collect();
213 Ok(records)
214 }
215
216 fn append_event(&self, event: EventLogEntry) -> Result<()> {
217 let mut events = self.events.write().unwrap();
218 events.push(event);
219 Ok(())
220 }
221
222 fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>> {
223 let events = self.events.read().unwrap();
224 let filtered: Vec<EventLogEntry> = events
225 .iter()
226 .filter(|e| {
227 e.operations.iter().any(|op| {
228 op.namespace == namespace && op.agent_id == agent_id
229 })
230 })
231 .filter(|e| {
232 if let Some(start) = start_ts {
233 e.commit_ts >= start
234 } else {
235 true
236 }
237 })
238 .filter(|e| {
239 if let Some(end) = end_ts {
240 e.commit_ts <= end
241 } else {
242 true
243 }
244 })
245 .cloned()
246 .collect();
247 Ok(filtered)
248 }
249
250 fn next_commit_ts(&self) -> Result<CommitTs> {
251 let mut counter = self.commit_ts_counter.write().unwrap();
252 *counter += 1;
253 Ok(*counter)
254 }
255
256 fn flush(&self) -> Result<()> {
257 Ok(())
258 }
259
260 fn create_snapshot(&self) -> Result<Snapshot> {
261 let state = self.state.read().unwrap();
262 let commit_ts_counter = self.commit_ts_counter.read().unwrap();
263
264 let mut records = Vec::new();
266 for versions in state.values() {
267 if let Some(record) = versions.last() {
268 records.push(record.clone());
269 }
270 }
271
272 let metadata = SnapshotMetadata {
273 version: SNAPSHOT_VERSION,
274 snapshot_ts: *commit_ts_counter,
275 record_count: records.len(),
276 created_at: std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .unwrap()
279 .as_secs(),
280 };
281
282 Ok(Snapshot { metadata, records })
283 }
284
285 fn save_snapshot(&self, _snapshot: &Snapshot) -> Result<()> {
286 Ok(())
288 }
289
290 fn load_snapshot(&self) -> Result<Option<Snapshot>> {
291 Ok(None)
293 }
294
295 fn get_all_state(&self) -> Result<Vec<StateRecord>> {
296 let state = self.state.read().unwrap();
297 let mut records = Vec::new();
298 for versions in state.values() {
299 if let Some(record) = versions.last() {
300 records.push(record.clone());
301 }
302 }
303 Ok(records)
304 }
305}
306
307use rocksdb::{Options, DB};
312
/// [`Storage`] backend persisted in a RocksDB database under `config.data_dir`.
pub struct RocksStorage {
    db: Arc<DB>,
    config: StorageConfig,
    /// In-memory copy of the last issued commit timestamp; each update is also
    /// written to the `__commit_ts__` key so it survives restarts.
    commit_ts_counter: Arc<RwLock<CommitTs>>,
}
318
319impl RocksStorage {
320 pub fn new(config: StorageConfig) -> Result<Self> {
321 std::fs::create_dir_all(&config.data_dir)?;
322
323 let mut opts = Options::default();
324 opts.create_if_missing(true);
325 opts.create_missing_column_families(true);
326
327 let db_path = config.data_dir.join("rocksdb");
328 let db = DB::open(&opts, db_path)?;
329
330 let commit_ts = if let Some(value) = db.get(b"__commit_ts__")? {
332 u64::from_be_bytes(value.try_into().unwrap_or([0; 8]))
333 } else {
334 0
335 };
336
337 Ok(Self {
338 db: Arc::new(db),
339 config,
340 commit_ts_counter: Arc::new(RwLock::new(commit_ts)),
341 })
342 }
343
344 pub fn restore_from_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
346 for record in &snapshot.records {
348 self.write_state(record.clone())?;
349 }
350
351 let mut counter = self.commit_ts_counter.write().unwrap();
353 *counter = snapshot.metadata.snapshot_ts;
354 self.db.put(b"__commit_ts__", &counter.to_be_bytes())?;
355
356 self.flush()?;
357 Ok(())
358 }
359
360 fn snapshot_path(&self) -> PathBuf {
362 self.config.data_dir.join("snapshot.json")
363 }
364
365 fn state_key(record_id: &RecordId) -> Vec<u8> {
366 format!("state:{}:{}:{}", record_id.namespace, record_id.agent_id, record_id.key).into_bytes()
367 }
368
369 fn version_key(record_id: &RecordId, version: Version) -> Vec<u8> {
370 format!("version:{}:{}:{}:{:020}", record_id.namespace, record_id.agent_id, record_id.key, version).into_bytes()
371 }
372
373 fn event_key(commit_ts: CommitTs) -> Vec<u8> {
374 format!("event:{:020}", commit_ts).into_bytes()
375 }
376}
377
378impl Storage for RocksStorage {
379 fn health_check(&self) -> Result<()> {
380 self.db.get(b"__health__")?;
382 Ok(())
383 }
384
385 fn write_state(&self, record: StateRecord) -> Result<()> {
386 let record_id = RecordId::new(
387 record.namespace.clone(),
388 record.agent_id.clone(),
389 record.key.clone(),
390 );
391
392 let state_key = Self::state_key(&record_id);
394 let state_value = serde_json::to_vec(&record)?;
395 self.db.put(&state_key, &state_value)?;
396
397 let version_key = Self::version_key(&record_id, record.version);
399 self.db.put(&version_key, &state_value)?;
400
401 if self.config.fsync_on_commit {
402 self.db.flush()?;
403 }
404
405 Ok(())
406 }
407
408 fn read_state(&self, record_id: &RecordId) -> Result<Option<StateRecord>> {
409 let key = Self::state_key(record_id);
410 if let Some(value) = self.db.get(&key)? {
411 let record: StateRecord = serde_json::from_slice(&value)?;
412 Ok(Some(record))
413 } else {
414 Ok(None)
415 }
416 }
417
418 fn read_state_at_version(&self, record_id: &RecordId, version: Version) -> Result<Option<StateRecord>> {
419 let key = Self::version_key(record_id, version);
420 if let Some(value) = self.db.get(&key)? {
421 let record: StateRecord = serde_json::from_slice(&value)?;
422 Ok(Some(record))
423 } else {
424 Ok(None)
425 }
426 }
427
428 fn list_keys(&self, namespace: &str, agent_id: &str) -> Result<Vec<String>> {
429 let prefix = format!("state:{}:{}:", namespace, agent_id);
430 let mut keys = Vec::new();
431
432 let iter = self.db.prefix_iterator(prefix.as_bytes());
433 for item in iter {
434 let (key, value) = item?;
435 let key_str = String::from_utf8_lossy(&key);
436 if !key_str.starts_with(&prefix) {
437 break;
438 }
439
440 let record: StateRecord = serde_json::from_slice(&value)?;
441 if !record.deleted {
442 keys.push(record.key);
443 }
444 }
445
446 Ok(keys)
447 }
448
449 fn scan_prefix(&self, namespace: &str, agent_id: &str, prefix: &str) -> Result<Vec<StateRecord>> {
450 let state_prefix = format!("state:{}:{}:{}", namespace, agent_id, prefix);
451 let mut records = Vec::new();
452
453 let iter = self.db.prefix_iterator(state_prefix.as_bytes());
454 for item in iter {
455 let (key, value) = item?;
456 let key_str = String::from_utf8_lossy(&key);
457 if !key_str.starts_with(&state_prefix) {
458 break;
459 }
460
461 let record: StateRecord = serde_json::from_slice(&value)?;
462 if !record.deleted {
463 records.push(record);
464 }
465 }
466
467 Ok(records)
468 }
469
470 fn append_event(&self, event: EventLogEntry) -> Result<()> {
471 let key = Self::event_key(event.commit_ts);
472 let value = serde_json::to_vec(&event)?;
473 self.db.put(&key, &value)?;
474
475 if self.config.fsync_on_commit {
476 self.db.flush()?;
477 }
478
479 Ok(())
480 }
481
482 fn replay_events(&self, namespace: &str, agent_id: &str, start_ts: Option<CommitTs>, end_ts: Option<CommitTs>) -> Result<Vec<EventLogEntry>> {
483 let start_key = if let Some(ts) = start_ts {
484 Self::event_key(ts)
485 } else {
486 b"event:".to_vec()
487 };
488
489 let mut events = Vec::new();
490 let iter = self.db.prefix_iterator(&start_key);
491
492 for item in iter {
493 let (key, value) = item?;
494 let key_str = String::from_utf8_lossy(&key);
495 if !key_str.starts_with("event:") {
496 break;
497 }
498
499 let event: EventLogEntry = serde_json::from_slice(&value)?;
500
501 let relevant = event.operations.iter().any(|op| {
503 op.namespace == namespace && op.agent_id == agent_id
504 });
505
506 if relevant {
507 if let Some(end) = end_ts {
508 if event.commit_ts > end {
509 break;
510 }
511 }
512 events.push(event);
513 }
514 }
515
516 Ok(events)
517 }
518
519 fn next_commit_ts(&self) -> Result<CommitTs> {
520 let mut counter = self.commit_ts_counter.write().unwrap();
521 *counter += 1;
522 let ts = *counter;
523
524 self.db.put(b"__commit_ts__", &ts.to_be_bytes())?;
526
527 Ok(ts)
528 }
529
530 fn flush(&self) -> Result<()> {
531 self.db.flush()?;
532 Ok(())
533 }
534
535 fn create_snapshot(&self) -> Result<Snapshot> {
536 let commit_ts_counter = self.commit_ts_counter.read().unwrap();
537 let records = self.get_all_state()?;
538
539 let metadata = SnapshotMetadata {
540 version: SNAPSHOT_VERSION,
541 snapshot_ts: *commit_ts_counter,
542 record_count: records.len(),
543 created_at: std::time::SystemTime::now()
544 .duration_since(std::time::UNIX_EPOCH)
545 .unwrap()
546 .as_secs(),
547 };
548
549 Ok(Snapshot { metadata, records })
550 }
551
552 fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
553 let path = self.snapshot_path();
554 let json = serde_json::to_string_pretty(snapshot)?;
555 std::fs::write(path, json)?;
556 Ok(())
557 }
558
559 fn load_snapshot(&self) -> Result<Option<Snapshot>> {
560 let path = self.snapshot_path();
561
562 if !path.exists() {
563 return Ok(None);
564 }
565
566 let json = std::fs::read_to_string(path)?;
567 let snapshot: Snapshot = serde_json::from_str(&json)?;
568
569 if snapshot.metadata.version != SNAPSHOT_VERSION {
571 return Err(anyhow::anyhow!(
572 "Snapshot version mismatch: expected {}, got {}",
573 SNAPSHOT_VERSION,
574 snapshot.metadata.version
575 ));
576 }
577
578 Ok(Some(snapshot))
579 }
580
581 fn get_all_state(&self) -> Result<Vec<StateRecord>> {
582 let mut records = Vec::new();
583 let iter = self.db.prefix_iterator(b"state:");
584
585 for item in iter {
586 let (key, value) = item?;
587 let key_str = String::from_utf8_lossy(&key);
588
589 if !key_str.starts_with("state:") {
590 break;
591 }
592
593 let record: StateRecord = serde_json::from_slice(&value)?;
594 records.push(record);
595 }
596
597 Ok(records)
598 }
599}