1use std::{
2 fs::{self, File},
3 io::{BufReader, BufWriter},
4 path::PathBuf,
5};
6
7use async_trait::async_trait;
8use bincode::{deserialize_from, serialize_into, Error as BincodeError};
9use serde::{de::DeserializeOwned, Deserialize, Serialize};
10use shuttle_service::{DeploymentMetadata, ResourceFactory, ResourceInputBuilder};
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub enum PersistError {
15 #[error("invalid key name")]
16 InvalidKey,
17 #[error("failed to open file: {0}")]
18 Open(std::io::Error),
19 #[error("failed to create folder: {0}")]
20 CreateFolder(std::io::Error),
21 #[error("failed to list contents of folder: {0}")]
22 ListFolder(std::io::Error),
23 #[error("failed to list file name: {0}")]
24 ListName(String),
25 #[error("failed to clear folder: {0}")]
26 RemoveFolder(std::io::Error),
27 #[error("failed to remove file: {0}")]
28 RemoveFile(std::io::Error),
29 #[error("failed to serialize data: {0}")]
30 Serialize(BincodeError),
31 #[error("failed to deserialize data: {0}")]
32 Deserialize(BincodeError),
33}
34
/// Resource builder that provisions a [`PersistInstance`] for the running
/// deployment (see its `ResourceInputBuilder` implementation).
///
/// It carries no configuration, so `Persist::default()` (or just `Persist`)
/// is all that is needed.
#[derive(Debug, Clone, Default)]
pub struct Persist;
37
38#[derive(Debug, Deserialize, Serialize, Clone)]
39pub struct PersistInstance {
40 dir: PathBuf,
41}
42
43impl PersistInstance {
44 pub fn new(dir: PathBuf) -> Result<Self, PersistError> {
46 fs::create_dir_all(&dir).map_err(PersistError::CreateFolder)?;
47
48 Ok(Self { dir })
49 }
50
51 pub fn save<T: Serialize>(&self, key: &str, value: T) -> Result<(), PersistError> {
53 let file_path = self.get_storage_file(key)?;
54 let file = File::create(file_path).map_err(PersistError::Open)?;
55 let mut writer = BufWriter::new(file);
56
57 serialize_into(&mut writer, &value).map_err(PersistError::Serialize)
58 }
59
60 fn entries(&self) -> Result<std::fs::ReadDir, PersistError> {
61 fs::read_dir(&self.dir).map_err(PersistError::ListFolder)
62 }
63
64 pub fn size(&self) -> Result<usize, PersistError> {
66 Ok(self.entries()?.count())
67 }
68
69 pub fn list(&self) -> Result<Vec<String>, PersistError> {
71 self.entries()?
72 .map(|entry| {
73 entry
74 .map_err(PersistError::ListFolder)?
75 .path()
76 .file_stem()
77 .unwrap_or_default()
78 .to_str()
79 .map(ToString::to_string)
80 .ok_or(PersistError::ListName(
81 "the file name contains invalid characters".to_owned(),
82 ))
83 })
84 .collect()
85 }
86
87 pub fn clear(&self) -> Result<(), PersistError> {
89 fs::remove_dir_all(&self.dir).map_err(PersistError::RemoveFolder)?;
90 fs::create_dir_all(&self.dir).map_err(PersistError::CreateFolder)?;
91
92 Ok(())
93 }
94
95 pub fn remove(&self, key: &str) -> Result<(), PersistError> {
97 let file_path = self.get_storage_file(key)?;
98 fs::remove_file(file_path).map_err(PersistError::RemoveFile)?;
99
100 Ok(())
101 }
102
103 pub fn load<T>(&self, key: &str) -> Result<T, PersistError>
105 where
106 T: DeserializeOwned,
107 {
108 let file_path = self.get_storage_file(key)?;
109 let file = File::open(file_path).map_err(PersistError::Open)?;
110 let reader = BufReader::new(file);
111
112 Ok(deserialize_from(reader).map_err(PersistError::Deserialize))?
113 }
114
115 fn get_storage_file(&self, key: &str) -> Result<PathBuf, PersistError> {
116 let p = self.dir.join(format!("{key}.bin"));
117 if p.parent().unwrap() != self.dir {
118 Err(PersistError::InvalidKey)
119 } else {
120 Ok(p)
121 }
122 }
123}
124
125#[async_trait]
126impl ResourceInputBuilder for Persist {
127 type Input = PersistInstance;
128 type Output = PersistInstance;
129
130 async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, shuttle_service::Error> {
131 let DeploymentMetadata {
132 project_name,
133 storage_path,
134 ..
135 } = factory.get_metadata();
136
137 PersistInstance::new(
138 storage_path
139 .join(PathBuf::from("shuttle-persist"))
140 .join(PathBuf::from(project_name)), )
142 .map_err(|e| shuttle_service::Error::Custom(e.into()))
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh, empty store under `test_output/<name>`, wiping any
    /// leftovers from a previous run first.
    fn setup(name: &str) -> PersistInstance {
        let dir = PathBuf::from(format!("test_output/{name}"));
        // Best effort: the directory usually does not exist yet.
        std::fs::remove_dir_all(&dir).ok();

        PersistInstance::new(dir).unwrap()
    }

    #[test]
    fn test_save_and_load() {
        let store = setup("test_save_and_load");

        store.save("test", "test").unwrap();
        let loaded: String = store.load("test").unwrap();
        assert_eq!(loaded, "test");
    }

    #[test]
    fn test_size() {
        let store = setup("test_size");
        let size = || store.size().unwrap();

        assert_eq!(size(), 0);
        store.save("test", "test").unwrap();
        assert_eq!(size(), 1);
        // Overwriting an existing key must not add a second entry.
        store.save("test", "test2").unwrap();
        assert_eq!(size(), 1);
        store.remove("test").unwrap();
        assert_eq!(size(), 0);
    }

    #[test]
    fn test_list() {
        let store = setup("test_list");

        assert!(store.list().unwrap().is_empty());
        store.save("test", "test").unwrap();
        assert_eq!(store.list().unwrap(), vec!["test".to_owned()]);
        store.remove("test").unwrap();
        assert!(store.list().unwrap().is_empty());
    }

    #[test]
    fn test_remove() {
        let store = setup("test_remove");

        store.save("test", "test").unwrap();
        store.save("test2", "test2").unwrap();
        let first_key = store.list().unwrap().remove(0);
        store.remove(&first_key).unwrap();
        assert_eq!(store.size().unwrap(), 1);
    }

    #[test]
    fn test_remove_error() {
        let store = setup("test_remove_error");

        assert!(store.remove("test").is_err());
    }

    #[test]
    fn test_clear() {
        let store = setup("test_clear");

        store.save("test", "test").unwrap();
        store.clear().unwrap();
        assert_eq!(store.size().unwrap(), 0);
    }

    #[test]
    fn test_load_error() {
        let store = setup("test_load_error");

        assert!(store.load::<String>("error").is_err());
    }

    #[test]
    fn test_weird_keys() {
        let store = setup("test_weird_keys");

        for key in [".", "\\"] {
            assert!(store.save(key, "test").is_ok(), "key {key:?} should be accepted");
        }

        for key in ["test/test", "../test", "/test", "~/test"] {
            assert!(store.save(key, "test").is_err(), "key {key:?} should be rejected");
        }
    }
}