// smoo_gadget_core/state_store.rs

1use anyhow::{Context, Result, anyhow, ensure};
2use bitflags::bitflags;
3use rand::{RngCore, rngs::OsRng};
4use serde::{Deserialize, Serialize};
5use std::{
6    fs::{self, File, OpenOptions},
7    io::{self, Write},
8    path::{Path, PathBuf},
9};
10
/// Schema version stamped into every snapshot written by `StateStore::persist`.
/// `StateStore::load` refuses state files carrying any other value.
const STATE_VERSION: u32 = 0;

13bitflags! {
14    #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
15    pub struct ExportFlags: u32 {
16        const READ_ONLY = 1 << 0;
17    }
18}
19
20#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
21pub struct ExportSpec {
22    pub block_size: u32,
23    pub size_bytes: u64,
24    pub flags: ExportFlags,
25}
26
27#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
28pub struct PersistedExportRecord {
29    pub export_id: u32,
30    pub spec: ExportSpec,
31    pub assigned_dev_id: Option<u32>,
32}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
35struct PersistedState {
36    version: u32,
37    session_id: u64,
38    exports: Vec<PersistedExportRecord>,
39}
40
41/// In-memory view of the persisted gadget state.
42///
43/// When `path` is `None`, persistence is disabled and `persist()` becomes a no-op.
44#[derive(Clone, Debug)]
45pub struct StateStore {
46    path: Option<PathBuf>,
47    session_id: u64,
48    records: Vec<PersistedExportRecord>,
49}
50
51impl Default for StateStore {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl StateStore {
58    /// Construct a fresh, in-memory store with a new session ID.
59    pub fn new() -> Self {
60        Self {
61            path: None,
62            session_id: generate_session_id(),
63            records: Vec::new(),
64        }
65    }
66
67    /// Construct a fresh, persistent store for `path` with a new session ID.
68    pub fn new_with_path(path: PathBuf) -> Self {
69        Self {
70            path: Some(path),
71            session_id: generate_session_id(),
72            records: Vec::new(),
73        }
74    }
75
76    /// Load state from `path`, returning an empty store with a new session ID when
77    /// the file does not exist.
78    pub fn load(path: PathBuf) -> Result<Self> {
79        match fs::read(&path) {
80            Ok(data) => {
81                let state: PersistedState =
82                    serde_json::from_slice(&data).context("decode state file")?;
83                ensure!(
84                    state.version == STATE_VERSION,
85                    "unsupported state version {}",
86                    state.version
87                );
88                Ok(Self {
89                    path: Some(path),
90                    session_id: state.session_id,
91                    records: state.exports,
92                })
93            }
94            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(Self {
95                path: Some(path),
96                session_id: generate_session_id(),
97                records: Vec::new(),
98            }),
99            Err(err) => Err(err).context("read state file"),
100        }
101    }
102
103    pub fn session_id(&self) -> u64 {
104        self.session_id
105    }
106
107    pub fn records(&self) -> &[PersistedExportRecord] {
108        &self.records
109    }
110
111    pub fn into_records(self) -> Vec<PersistedExportRecord> {
112        self.records
113    }
114
115    pub fn path(&self) -> Option<&Path> {
116        self.path.as_deref()
117    }
118
119    pub fn replace_all(&mut self, records: Vec<PersistedExportRecord>) {
120        self.records = records;
121    }
122
123    pub fn upsert_record(&mut self, record: PersistedExportRecord) {
124        match self
125            .records
126            .iter()
127            .position(|existing| existing.export_id == record.export_id)
128        {
129            Some(idx) => self.records[idx] = record,
130            None => self.records.push(record),
131        }
132    }
133
134    pub fn update_record<F>(&mut self, export_id: u32, f: F) -> Result<()>
135    where
136        F: FnOnce(&mut PersistedExportRecord),
137    {
138        let record = self
139            .records
140            .iter_mut()
141            .find(|record| record.export_id == export_id)
142            .ok_or_else(|| anyhow!("export {export_id} not found in state store"))?;
143        f(record);
144        Ok(())
145    }
146
147    pub fn remove_record(&mut self, export_id: u32) {
148        if let Some(idx) = self
149            .records
150            .iter()
151            .position(|record| record.export_id == export_id)
152        {
153            self.records.swap_remove(idx);
154        }
155    }
156
157    /// Persist the current snapshot to disk. No-op when persistence is disabled.
158    pub fn persist(&self) -> Result<()> {
159        let Some(path) = &self.path else {
160            return Ok(());
161        };
162
163        let state = PersistedState {
164            version: STATE_VERSION,
165            session_id: self.session_id,
166            exports: self.records.clone(),
167        };
168        let payload = serde_json::to_vec(&state).context("encode state snapshot")?;
169        let dir = path.parent().unwrap_or_else(|| Path::new("."));
170        fs::create_dir_all(dir).context("create state directory")?;
171        let dir_file = File::open(dir).context("open state directory for sync")?;
172
173        let tmp_path = path.with_extension("tmp");
174        {
175            let mut file = OpenOptions::new()
176                .create(true)
177                .truncate(true)
178                .write(true)
179                .open(&tmp_path)
180                .with_context(|| format!("open temporary state file {}", tmp_path.display()))?;
181            file.write_all(&payload)
182                .with_context(|| format!("write {}", tmp_path.display()))?;
183            file.sync_all()
184                .with_context(|| format!("flush {}", tmp_path.display()))?;
185        }
186
187        fs::rename(&tmp_path, path)
188            .with_context(|| format!("commit state file to {}", path.display()))?;
189        dir_file
190            .sync_all()
191            .context("sync state directory after rename")?;
192        Ok(())
193    }
194
195    /// Remove the state file from disk, if persistence is enabled.
196    pub fn remove_file(&self) -> Result<()> {
197        let Some(path) = &self.path else {
198            return Ok(());
199        };
200
201        match fs::remove_file(path) {
202            Ok(()) => {}
203            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
204            Err(err) => {
205                return Err(err).with_context(|| format!("remove state file {}", path.display()));
206            }
207        }
208
209        if let Some(dir) = path.parent() {
210            if let Ok(dir_file) = File::open(dir) {
211                let _ = dir_file.sync_all();
212            }
213        }
214        Ok(())
215    }
216}
217
218fn generate_session_id() -> u64 {
219    loop {
220        let candidate = OsRng.next_u64();
221        if candidate != 0 {
222            return candidate;
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a record with fixed geometry; only the identity fields vary.
    fn sample_record(export_id: u32, assigned_dev_id: Option<u32>) -> PersistedExportRecord {
        PersistedExportRecord {
            export_id,
            spec: ExportSpec {
                block_size: 4096,
                size_bytes: 4096 * 8,
                flags: ExportFlags::READ_ONLY,
            },
            assigned_dev_id,
        }
    }

    #[test]
    fn new_in_memory_has_session() {
        let store = StateStore::new();
        assert_ne!(store.session_id(), 0);
        assert!(store.path().is_none());
    }

    #[test]
    fn in_memory_persist_and_remove_are_noops() {
        let store = StateStore::new();
        store.persist().unwrap();
        store.remove_file().unwrap();
    }

    #[test]
    fn persist_round_trip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        let mut store = StateStore::load(path.clone()).unwrap();
        assert!(store.records().is_empty());

        let record = sample_record(1, Some(7));
        store.upsert_record(record.clone());
        store.persist().unwrap();

        let loaded = StateStore::load(path).unwrap();
        assert_eq!(store.session_id(), loaded.session_id());
        assert_eq!(loaded.records(), &[record]);
    }

    #[test]
    fn persist_leaves_no_temp_files() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        let mut store = StateStore::new_with_path(path);
        store.upsert_record(sample_record(1, None));
        store.persist().unwrap();

        let names: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("state.json")]);
    }

    #[test]
    fn persist_path_ending_in_tmp_round_trips() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.tmp");
        let mut store = StateStore::new_with_path(path.clone());
        store.upsert_record(sample_record(5, Some(1)));
        store.persist().unwrap();

        let loaded = StateStore::load(path).unwrap();
        assert_eq!(loaded.records(), &[sample_record(5, Some(1))]);
    }

    #[test]
    fn load_missing_creates_new_session() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing.json");
        let store = StateStore::load(path).unwrap();
        assert!(store.records().is_empty());
        assert_ne!(store.session_id(), 0);
    }

    #[test]
    fn load_rejects_unsupported_version() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, br#"{"version":99,"session_id":5,"exports":[]}"#).unwrap();
        assert!(StateStore::load(path).is_err());
    }

    #[test]
    fn load_rejects_corrupt_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, b"not json").unwrap();
        assert!(StateStore::load(path).is_err());
    }

    #[test]
    fn upsert_replaces_existing_record_in_place() {
        let mut store = StateStore::new();
        store.upsert_record(sample_record(1, Some(7)));
        store.upsert_record(sample_record(2, None));
        store.upsert_record(sample_record(1, None));
        assert_eq!(
            store.records(),
            &[sample_record(1, None), sample_record(2, None)]
        );
    }

    #[test]
    fn update_record_mutates_or_reports_missing() {
        let mut store = StateStore::new();
        store.upsert_record(sample_record(3, None));
        store
            .update_record(3, |record| record.assigned_dev_id = Some(9))
            .unwrap();
        assert_eq!(store.records()[0].assigned_dev_id, Some(9));
        assert!(
            store
                .update_record(4, |record| record.assigned_dev_id = None)
                .is_err()
        );
    }

    #[test]
    fn remove_record_drops_only_matching_id() {
        let mut store = StateStore::new();
        store.upsert_record(sample_record(1, None));
        store.upsert_record(sample_record(2, None));
        store.remove_record(1);
        // Unknown IDs are silently ignored.
        store.remove_record(42);
        assert_eq!(store.records(), &[sample_record(2, None)]);
    }

    #[test]
    fn remove_file_is_idempotent() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        let store = StateStore::new_with_path(path.clone());
        store.persist().unwrap();
        assert!(path.exists());

        store.remove_file().unwrap();
        assert!(!path.exists());
        store.remove_file().unwrap();
    }
}