shuttle_persist/
lib.rs

1use std::{
2    fs::{self, File},
3    io::{BufReader, BufWriter},
4    path::PathBuf,
5};
6
7use async_trait::async_trait;
8use bincode::{deserialize_from, serialize_into, Error as BincodeError};
9use serde::{de::DeserializeOwned, Deserialize, Serialize};
10use shuttle_service::{DeploymentMetadata, ResourceFactory, ResourceInputBuilder};
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub enum PersistError {
15    #[error("invalid key name")]
16    InvalidKey,
17    #[error("failed to open file: {0}")]
18    Open(std::io::Error),
19    #[error("failed to create folder: {0}")]
20    CreateFolder(std::io::Error),
21    #[error("failed to list contents of folder: {0}")]
22    ListFolder(std::io::Error),
23    #[error("failed to list file name: {0}")]
24    ListName(String),
25    #[error("failed to clear folder: {0}")]
26    RemoveFolder(std::io::Error),
27    #[error("failed to remove file: {0}")]
28    RemoveFile(std::io::Error),
29    #[error("failed to serialize data: {0}")]
30    Serialize(BincodeError),
31    #[error("failed to deserialize data: {0}")]
32    Deserialize(BincodeError),
33}
34
/// Resource marker that requests a [`PersistInstance`] for the running service.
///
/// It carries no configuration; its `ResourceInputBuilder` implementation decides
/// where the storage folder is placed.
#[derive(Clone, Debug, Default)]
pub struct Persist;
37
38#[derive(Debug, Deserialize, Serialize, Clone)]
39pub struct PersistInstance {
40    dir: PathBuf,
41}
42
43impl PersistInstance {
44    /// Constructs a PersistInstance and creates its associated storage folder
45    pub fn new(dir: PathBuf) -> Result<Self, PersistError> {
46        fs::create_dir_all(&dir).map_err(PersistError::CreateFolder)?;
47
48        Ok(Self { dir })
49    }
50
51    /// Save a key-value pair to disk
52    pub fn save<T: Serialize>(&self, key: &str, value: T) -> Result<(), PersistError> {
53        let file_path = self.get_storage_file(key)?;
54        let file = File::create(file_path).map_err(PersistError::Open)?;
55        let mut writer = BufWriter::new(file);
56
57        serialize_into(&mut writer, &value).map_err(PersistError::Serialize)
58    }
59
60    fn entries(&self) -> Result<std::fs::ReadDir, PersistError> {
61        fs::read_dir(&self.dir).map_err(PersistError::ListFolder)
62    }
63
64    /// Returns the number of keys in this instance
65    pub fn size(&self) -> Result<usize, PersistError> {
66        Ok(self.entries()?.count())
67    }
68
69    /// Returns a vector of strings containing all the keys in this instance
70    pub fn list(&self) -> Result<Vec<String>, PersistError> {
71        self.entries()?
72            .map(|entry| {
73                entry
74                    .map_err(PersistError::ListFolder)?
75                    .path()
76                    .file_stem()
77                    .unwrap_or_default()
78                    .to_str()
79                    .map(ToString::to_string)
80                    .ok_or(PersistError::ListName(
81                        "the file name contains invalid characters".to_owned(),
82                    ))
83            })
84            .collect()
85    }
86
87    /// Removes all keys
88    pub fn clear(&self) -> Result<(), PersistError> {
89        fs::remove_dir_all(&self.dir).map_err(PersistError::RemoveFolder)?;
90        fs::create_dir_all(&self.dir).map_err(PersistError::CreateFolder)?;
91
92        Ok(())
93    }
94
95    /// Deletes a key from the PersistInstance
96    pub fn remove(&self, key: &str) -> Result<(), PersistError> {
97        let file_path = self.get_storage_file(key)?;
98        fs::remove_file(file_path).map_err(PersistError::RemoveFile)?;
99
100        Ok(())
101    }
102
103    /// Loads a value from disk
104    pub fn load<T>(&self, key: &str) -> Result<T, PersistError>
105    where
106        T: DeserializeOwned,
107    {
108        let file_path = self.get_storage_file(key)?;
109        let file = File::open(file_path).map_err(PersistError::Open)?;
110        let reader = BufReader::new(file);
111
112        Ok(deserialize_from(reader).map_err(PersistError::Deserialize))?
113    }
114
115    fn get_storage_file(&self, key: &str) -> Result<PathBuf, PersistError> {
116        let p = self.dir.join(format!("{key}.bin"));
117        if p.parent().unwrap() != self.dir {
118            Err(PersistError::InvalidKey)
119        } else {
120            Ok(p)
121        }
122    }
123}
124
125#[async_trait]
126impl ResourceInputBuilder for Persist {
127    type Input = PersistInstance;
128    type Output = PersistInstance;
129
130    async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, shuttle_service::Error> {
131        let DeploymentMetadata {
132            project_name,
133            storage_path,
134            ..
135        } = factory.get_metadata();
136
137        PersistInstance::new(
138            storage_path
139                .join(PathBuf::from("shuttle-persist"))
140                .join(PathBuf::from(project_name)), // separate persist directories per service
141        )
142        .map_err(|e| shuttle_service::Error::Custom(e.into()))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an instance backed by a freshly emptied `test_output/<name>` folder.
    fn setup(name: &str) -> PersistInstance {
        let dir = PathBuf::from(format!("test_output/{name}"));
        // A leftover folder from a previous run may or may not exist; either is fine.
        let _ = std::fs::remove_dir_all(&dir);

        PersistInstance::new(dir).unwrap()
    }

    #[test]
    fn test_save_and_load() {
        let persist = setup("test_save_and_load");

        persist.save("test", "test").unwrap();
        assert_eq!(persist.load::<String>("test").unwrap(), "test");
    }

    #[test]
    fn test_size() {
        let persist = setup("test_size");
        let size = || persist.size().unwrap();

        assert_eq!(size(), 0);
        persist.save("test", "test").unwrap();
        assert_eq!(size(), 1);
        // Saving again under the same key overwrites instead of adding a key.
        persist.save("test", "test2").unwrap();
        assert_eq!(size(), 1);
        persist.remove("test").unwrap();
        assert_eq!(size(), 0);
    }

    #[test]
    fn test_list() {
        let persist = setup("test_list");

        assert!(persist.list().unwrap().is_empty());
        persist.save("test", "test").unwrap();
        assert_eq!(persist.list().unwrap(), vec!["test".to_owned()]);
        persist.remove("test").unwrap();
        assert!(persist.list().unwrap().is_empty());
    }

    #[test]
    fn test_remove() {
        let persist = setup("test_remove");

        persist.save("test", "test").unwrap();
        persist.save("test2", "test2").unwrap();
        let first_key = persist.list().unwrap().remove(0);
        persist.remove(&first_key).unwrap();
        assert_eq!(persist.size().unwrap(), 1);
    }

    #[test]
    fn test_remove_error() {
        let persist = setup("test_remove_error");

        // Nothing was saved, so there is no file to delete.
        assert!(persist.remove("test").is_err());
    }

    #[test]
    fn test_clear() {
        let persist = setup("test_clear");

        persist.save("test", "test").unwrap();
        persist.clear().unwrap();
        assert_eq!(persist.size().unwrap(), 0);
    }

    #[test]
    fn test_load_error() {
        let persist = setup("test_load_error");

        assert!(persist.load::<String>("error").is_err());
    }

    #[test]
    fn test_weird_keys() {
        let persist = setup("test_weird_keys");

        // Linux is the main concern
        for key in [".", "\\"] {
            assert!(persist.save(key, "test").is_ok(), "key {key:?} should be accepted");
        }
        for key in ["test/test", "../test", "/test", "~/test"] {
            assert!(persist.save(key, "test").is_err(), "key {key:?} should be rejected");
        }
    }
}