1use anyhow::{Context, Result, anyhow, ensure};
2use bitflags::bitflags;
3use rand::{RngCore, rngs::OsRng};
4use serde::{Deserialize, Serialize};
5use std::{
6 fs::{self, File, OpenOptions},
7 io::{self, Write},
8 path::{Path, PathBuf},
9};
10
11const STATE_VERSION: u32 = 0;
12
13bitflags! {
14 #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
15 pub struct ExportFlags: u32 {
16 const READ_ONLY = 1 << 0;
17 }
18}
19
20#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
21pub struct ExportSpec {
22 pub block_size: u32,
23 pub size_bytes: u64,
24 pub flags: ExportFlags,
25}
26
27#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
28pub struct PersistedExportRecord {
29 pub export_id: u32,
30 pub spec: ExportSpec,
31 pub assigned_dev_id: Option<u32>,
32}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
35struct PersistedState {
36 version: u32,
37 session_id: u64,
38 exports: Vec<PersistedExportRecord>,
39}
40
41#[derive(Clone, Debug)]
45pub struct StateStore {
46 path: Option<PathBuf>,
47 session_id: u64,
48 records: Vec<PersistedExportRecord>,
49}
50
51impl Default for StateStore {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl StateStore {
58 pub fn new() -> Self {
60 Self {
61 path: None,
62 session_id: generate_session_id(),
63 records: Vec::new(),
64 }
65 }
66
67 pub fn new_with_path(path: PathBuf) -> Self {
69 Self {
70 path: Some(path),
71 session_id: generate_session_id(),
72 records: Vec::new(),
73 }
74 }
75
76 pub fn load(path: PathBuf) -> Result<Self> {
79 match fs::read(&path) {
80 Ok(data) => {
81 let state: PersistedState =
82 serde_json::from_slice(&data).context("decode state file")?;
83 ensure!(
84 state.version == STATE_VERSION,
85 "unsupported state version {}",
86 state.version
87 );
88 Ok(Self {
89 path: Some(path),
90 session_id: state.session_id,
91 records: state.exports,
92 })
93 }
94 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(Self {
95 path: Some(path),
96 session_id: generate_session_id(),
97 records: Vec::new(),
98 }),
99 Err(err) => Err(err).context("read state file"),
100 }
101 }
102
103 pub fn session_id(&self) -> u64 {
104 self.session_id
105 }
106
107 pub fn records(&self) -> &[PersistedExportRecord] {
108 &self.records
109 }
110
111 pub fn into_records(self) -> Vec<PersistedExportRecord> {
112 self.records
113 }
114
115 pub fn path(&self) -> Option<&Path> {
116 self.path.as_deref()
117 }
118
119 pub fn replace_all(&mut self, records: Vec<PersistedExportRecord>) {
120 self.records = records;
121 }
122
123 pub fn upsert_record(&mut self, record: PersistedExportRecord) {
124 match self
125 .records
126 .iter()
127 .position(|existing| existing.export_id == record.export_id)
128 {
129 Some(idx) => self.records[idx] = record,
130 None => self.records.push(record),
131 }
132 }
133
134 pub fn update_record<F>(&mut self, export_id: u32, f: F) -> Result<()>
135 where
136 F: FnOnce(&mut PersistedExportRecord),
137 {
138 let record = self
139 .records
140 .iter_mut()
141 .find(|record| record.export_id == export_id)
142 .ok_or_else(|| anyhow!("export {export_id} not found in state store"))?;
143 f(record);
144 Ok(())
145 }
146
147 pub fn remove_record(&mut self, export_id: u32) {
148 if let Some(idx) = self
149 .records
150 .iter()
151 .position(|record| record.export_id == export_id)
152 {
153 self.records.swap_remove(idx);
154 }
155 }
156
157 pub fn persist(&self) -> Result<()> {
159 let Some(path) = &self.path else {
160 return Ok(());
161 };
162
163 let state = PersistedState {
164 version: STATE_VERSION,
165 session_id: self.session_id,
166 exports: self.records.clone(),
167 };
168 let payload = serde_json::to_vec(&state).context("encode state snapshot")?;
169 let dir = path.parent().unwrap_or_else(|| Path::new("."));
170 fs::create_dir_all(dir).context("create state directory")?;
171 let dir_file = File::open(dir).context("open state directory for sync")?;
172
173 let tmp_path = path.with_extension("tmp");
174 {
175 let mut file = OpenOptions::new()
176 .create(true)
177 .truncate(true)
178 .write(true)
179 .open(&tmp_path)
180 .with_context(|| format!("open temporary state file {}", tmp_path.display()))?;
181 file.write_all(&payload)
182 .with_context(|| format!("write {}", tmp_path.display()))?;
183 file.sync_all()
184 .with_context(|| format!("flush {}", tmp_path.display()))?;
185 }
186
187 fs::rename(&tmp_path, path)
188 .with_context(|| format!("commit state file to {}", path.display()))?;
189 dir_file
190 .sync_all()
191 .context("sync state directory after rename")?;
192 Ok(())
193 }
194
195 pub fn remove_file(&self) -> Result<()> {
197 let Some(path) = &self.path else {
198 return Ok(());
199 };
200
201 match fs::remove_file(path) {
202 Ok(()) => {}
203 Err(err) if err.kind() == io::ErrorKind::NotFound => {}
204 Err(err) => {
205 return Err(err).with_context(|| format!("remove state file {}", path.display()));
206 }
207 }
208
209 if let Some(dir) = path.parent() {
210 if let Ok(dir_file) = File::open(dir) {
211 let _ = dir_file.sync_all();
212 }
213 }
214 Ok(())
215 }
216}
217
218fn generate_session_id() -> u64 {
219 loop {
220 let candidate = OsRng.next_u64();
221 if candidate != 0 {
222 return candidate;
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a representative record with the given id.
    fn sample_record(export_id: u32) -> PersistedExportRecord {
        PersistedExportRecord {
            export_id,
            spec: ExportSpec {
                block_size: 4096,
                size_bytes: 4096 * 8,
                flags: ExportFlags::READ_ONLY,
            },
            assigned_dev_id: Some(7),
        }
    }

    #[test]
    fn new_in_memory_has_session() {
        let store = StateStore::new();
        assert_ne!(store.session_id(), 0);
        assert!(store.path().is_none());
    }

    #[test]
    fn in_memory_persist_and_remove_are_noops() {
        let store = StateStore::new();
        store.persist().unwrap();
        store.remove_file().unwrap();
    }

    #[test]
    fn persist_round_trip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        let mut store = StateStore::load(path.clone()).unwrap();
        assert!(store.records().is_empty());

        let record = sample_record(1);
        store.upsert_record(record.clone());
        store.persist().unwrap();

        let loaded = StateStore::load(path).unwrap();
        assert_eq!(store.session_id(), loaded.session_id());
        assert_eq!(loaded.records(), &[record]);
    }

    #[test]
    fn load_missing_creates_new_session() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing.json");
        let store = StateStore::load(path).unwrap();
        assert!(store.records().is_empty());
        assert_ne!(store.session_id(), 0);
    }

    #[test]
    fn load_rejects_unknown_version() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, br#"{"version":99,"session_id":5,"exports":[]}"#).unwrap();
        assert!(StateStore::load(path).is_err());
    }

    #[test]
    fn load_rejects_corrupt_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, b"not json").unwrap();
        assert!(StateStore::load(path).is_err());
    }

    #[test]
    fn upsert_replaces_existing_record() {
        let mut store = StateStore::new();
        store.upsert_record(sample_record(1));
        let mut updated = sample_record(1);
        updated.assigned_dev_id = None;
        store.upsert_record(updated.clone());
        assert_eq!(store.records(), &[updated]);
    }

    #[test]
    fn update_record_requires_existing_id() {
        let mut store = StateStore::new();
        assert!(store.update_record(1, |_| {}).is_err());

        store.upsert_record(sample_record(1));
        store
            .update_record(1, |record| record.assigned_dev_id = Some(9))
            .unwrap();
        assert_eq!(store.records()[0].assigned_dev_id, Some(9));
    }

    #[test]
    fn remove_record_drops_only_matching_id() {
        let mut store = StateStore::new();
        for id in 1..=3 {
            store.upsert_record(sample_record(id));
        }
        store.remove_record(1);
        // Unknown ids are ignored.
        store.remove_record(42);

        let mut ids: Vec<u32> = store.records().iter().map(|r| r.export_id).collect();
        ids.sort_unstable();
        assert_eq!(ids, [2, 3]);
    }

    #[test]
    fn remove_file_deletes_and_tolerates_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("state.json");
        let store = StateStore::new_with_path(path.clone());
        store.persist().unwrap();
        assert!(path.exists());

        store.remove_file().unwrap();
        assert!(!path.exists());
        // A second removal of an already-missing file is not an error.
        store.remove_file().unwrap();
    }
}