statefile/
lib.rs

1use serde::de::DeserializeOwned;
2use serde::Serialize;
3use std::error::Error;
4use std::fs::OpenOptions;
5use std::io::prelude::*;
6use std::path::{Path, PathBuf};
7use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
8
9pub struct WriteGuard<'a, T: Serialize + DeserializeOwned + Default> {
10    guard: RwLockWriteGuard<'a, T>,
11    path: PathBuf,
12}
13
14impl<'a, T: Serialize + DeserializeOwned + Default> Drop for WriteGuard<'a, T> {
15    fn drop(&mut self) {
16        // convert data structure to pretty JSON string
17        let json = match serde_json::to_string_pretty(&*self.guard) {
18            Ok(v) => v,
19            Err(e) => {
20                log::error!("Failed to serialize JSON: {}", e);
21                return;
22            }
23        };
24
25        // open the state file
26        let path = self.path.clone();
27        let mut file = match OpenOptions::new().write(true).create(true).open(&path) {
28            Ok(v) => v,
29            Err(e) => {
30                log::error!("Failed to open file {}: {}", path.display(), e);
31                return;
32            }
33        };
34
35        // write to disk
36        if let Err(e) = file.write_all(json.as_bytes()) {
37            log::error!("Failed to write to file {}: {}", path.display(), e);
38            return;
39        }
40
41        // ensure data makes it to disk
42        if let Err(e) = file.flush() {
43            log::error!("Failed to flush file {}: {}", path.display(), e);
44            return;
45        }
46
47        log::info!("Data successfully written to file {}", path.display())
48    }
49}
50
51impl<'a, T: Serialize + DeserializeOwned + Default> std::ops::Deref for WriteGuard<'a, T> {
52    type Target = T;
53
54    fn deref(&self) -> &Self::Target {
55        &self.guard
56    }
57}
58
59impl<'a, T: Serialize + DeserializeOwned + Default> std::ops::DerefMut for WriteGuard<'a, T> {
60    fn deref_mut(&mut self) -> &mut Self::Target {
61        &mut self.guard
62    }
63}
64
65/// A state file.
66///
67/// This provides strongly typed access to a JSON file wrapped in a `RwLock`
68/// that writes to disk once write access is dropped.
69///
70/// ```rust
71/// use statefile::File;
72/// use serde::{Deserialize, Serialize};
73///
74/// // you must specify at least these derivations
75/// #[derive(Serialize, Deserialize, Default)]
76/// struct State {
77///     foo: String,
78///     bar: u32,
79/// }
80///
81/// #[tokio::main]
82/// async fn main() {
83///     // create or open state file at given path
84///     let mut state = File::<State>::new("mystate.json").await.unwrap();
85///     // if the file doesn't exist or is empty, State will contain default values
86///
87///     let mut write_guard = state.write().await; // grab write access
88///     write_guard.foo = "".to_string();
89///     write_guard.bar = 10;
90///     drop(write_guard); // write state by explicitly dropping
91/// }
92/// ```
93///
94pub struct File<T: Serialize + DeserializeOwned + Default> {
95    data: RwLock<T>,
96    path: PathBuf,
97}
98
99impl<T: Serialize + DeserializeOwned + Default> File<T> {
100    /// Create a new state file at the given path
101    pub async fn new(path: impl AsRef<Path> + Copy) -> Result<Self, Box<dyn Error>> {
102        let mut file = OpenOptions::new()
103            .read(true)
104            .write(true)
105            .create(true)
106            .open(path)?;
107
108        let mut contents = String::new();
109        file.read_to_string(&mut contents)?;
110
111        let data = if contents.is_empty() {
112            T::default()
113        } else {
114            serde_json::from_str(&contents)?
115        };
116
117        let data = RwLock::new(data);
118
119        let path = path.as_ref().to_path_buf();
120
121        Ok(File { data, path })
122    }
123
124    /// Locks this state file with shared read access, causing the current task
125    /// to yield until the lock has been acquired.
126    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
127        self.data.read().await
128    }
129
130    /// Locks this state file with exclusive write access, causing the current
131    /// task to yield until the lock has been acquired.
132    pub async fn write(&self) -> WriteGuard<'_, T> {
133        WriteGuard {
134            guard: self.data.write().await,
135            path: self.path.clone(),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::fs;
    use std::path::{Path, PathBuf};

    #[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
    struct TestData {
        field1: String,
        field2: u32,
    }

    /// Scratch file path in the system temp directory.
    ///
    /// The file is removed both when the guard is created (in case an earlier
    /// run left it behind) and when it is dropped, so cleanup still happens if
    /// an assertion panics part-way through a test.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            // Include the process id so concurrent test runs don't collide.
            let path = std::env::temp_dir()
                .join(format!("statefile-{}-{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            TempPath(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn test_file_create_and_write() {
        let tmp = TempPath::new("test_file_create_and_write.json");
        let test_path = tmp.path();
        let file = File::<TestData>::new(test_path).await.unwrap();

        let mut write_guard = file.write().await;
        write_guard.field1 = String::from("Test String");
        write_guard.field2 = 42;
        drop(write_guard); // Forces the Drop trait to be called, data should be written to the file

        let file_content = fs::read_to_string(test_path).unwrap();

        assert_eq!(
            file_content,
            r#"{
  "field1": "Test String",
  "field2": 42
}"#
        );
    }

    #[tokio::test]
    async fn test_file_read() {
        let tmp = TempPath::new("test_file_read.json");
        let test_path = tmp.path();
        fs::write(test_path, r#"{"field1":"Test String","field2":42}"#).unwrap(); // Write initial data

        let file = File::<TestData>::new(test_path).await.unwrap();
        let read_guard = file.read().await;

        assert_eq!(read_guard.field1, "Test String");
        assert_eq!(read_guard.field2, 42);
    }

    #[tokio::test]
    async fn test_file_read_default() {
        let tmp = TempPath::new("test_file_read_default.json");
        let test_path = tmp.path();
        fs::write(test_path, "").unwrap(); // Write empty file

        let file = File::<TestData>::new(test_path).await.unwrap();
        let read_guard = file.read().await;

        // Check default values
        assert_eq!(read_guard.field1, "");
        assert_eq!(read_guard.field2, 0);
    }
}