1use std::io::{Cursor, Read, Write};
4
5use super::vfs::{Node, Vfs};
6
/// File signature written at offset 0 of every `.dev_shell.bin` image.
const MAGIC: &[u8; 4] = &[b'D', b'E', b'V', b'S'];
/// Format version byte that immediately follows [`MAGIC`].
const VERSION: u8 = 0x01;

/// Node tag byte for a directory entry.
const NODE_DIR: u8 = 0x00;
/// Node tag byte for a regular file entry.
const NODE_FILE: u8 = 0x01;
12
/// Failure modes when encoding or decoding a `.dev_shell.bin` image.
#[derive(Debug)]
pub enum Error {
    /// The first four bytes are not the expected magic signature.
    InvalidMagic,
    /// The version byte is not the one this module writes.
    InvalidVersion,
    /// The input ended before a complete image could be read. The decoder in
    /// this module also reports an unrecognised node tag with this variant.
    Truncated,
    /// The cwd or a node name was not valid UTF-8.
    InvalidUtf8(std::string::FromUtf8Error),
    /// An underlying I/O operation failed.
    Io(std::io::Error),
    /// `serialize` was asked to save a host-backed workspace.
    HostBacked,
}
23
24impl std::fmt::Display for Error {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 Self::InvalidMagic => write!(f, "invalid magic"),
28 Self::InvalidVersion => write!(f, "invalid version"),
29 Self::Truncated => write!(f, "truncated data"),
30 Self::InvalidUtf8(e) => write!(f, "invalid utf-8: {e}"),
31 Self::Io(e) => write!(f, "io error: {e}"),
32 Self::HostBacked => {
33 f.write_str("host-backed workspace cannot be saved as .dev_shell.bin")
34 }
35 }
36 }
37}
38
39impl std::error::Error for Error {
40 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
41 match self {
42 Self::InvalidUtf8(e) => Some(e),
43 Self::Io(e) => Some(e),
44 _ => None,
45 }
46 }
47}
48
49impl From<std::io::Error> for Error {
50 fn from(e: std::io::Error) -> Self {
51 Self::Io(e)
52 }
53}
54
/// Writes `n` to `w` as four little-endian bytes.
fn write_u32_le(w: &mut impl Write, n: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = n.to_le_bytes();
    w.write_all(encoded.as_slice())
}
58
/// Writes `n` to `w` as two little-endian bytes.
fn write_u16_le(w: &mut impl Write, n: u16) -> std::io::Result<()> {
    let encoded = n.to_le_bytes();
    w.write_all(&encoded[..])
}
62
/// Writes `n` to `w` as eight little-endian bytes.
fn write_u64_le(w: &mut impl Write, n: u64) -> std::io::Result<()> {
    w.write_all(&u64::to_le_bytes(n))
}
66
/// Reads exactly four bytes from `r` and decodes them as a little-endian `u32`.
///
/// Fails with the reader's error (e.g. `UnexpectedEof`) if fewer bytes remain.
fn read_u32_le(r: &mut impl Read) -> std::io::Result<u32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
72
/// Reads exactly two bytes from `r` and decodes them as a little-endian `u16`.
///
/// Fails with the reader's error (e.g. `UnexpectedEof`) if fewer bytes remain.
fn read_u16_le(r: &mut impl Read) -> std::io::Result<u16> {
    let mut raw = [0u8; 2];
    r.read_exact(&mut raw)?;
    let value = u16::from_le_bytes(raw);
    Ok(value)
}
78
/// Reads exactly eight bytes from `r` and decodes them as a little-endian `u64`.
///
/// Fails with the reader's error (e.g. `UnexpectedEof`) if fewer bytes remain.
fn read_u64_le(r: &mut impl Read) -> std::io::Result<u64> {
    let mut raw: [u8; 8] = [0; 8];
    match r.read_exact(&mut raw) {
        Ok(()) => Ok(u64::from_le_bytes(raw)),
        Err(e) => Err(e),
    }
}
84
85fn serialize_node(w: &mut impl Write, node: &Node) -> std::io::Result<()> {
86 match node {
87 Node::Dir { name, children } => {
88 w.write_all(&[NODE_DIR])?;
89 let name_bytes = name.as_bytes();
90 let len_u16 = u16::try_from(name_bytes.len())
91 .map_err(|_| std::io::Error::other("name len overflow"))?;
92 write_u16_le(w, len_u16)?;
93 w.write_all(name_bytes)?;
94 let len_u32 = u32::try_from(children.len())
95 .map_err(|_| std::io::Error::other("children len overflow"))?;
96 write_u32_le(w, len_u32)?;
97 for child in children {
98 serialize_node(w, child)?;
99 }
100 }
101 Node::File { name, content } => {
102 w.write_all(&[NODE_FILE])?;
103 let name_bytes = name.as_bytes();
104 let len_u16 = u16::try_from(name_bytes.len())
105 .map_err(|_| std::io::Error::other("name len overflow"))?;
106 write_u16_le(w, len_u16)?;
107 w.write_all(name_bytes)?;
108 write_u64_le(w, content.len() as u64)?;
109 w.write_all(content)?;
110 }
111 }
112 Ok(())
113}
114
115fn deserialize_node(r: &mut impl Read) -> Result<Node, Error> {
116 let mut tag = [0u8; 1];
117 r.read_exact(&mut tag).map_err(|e| {
118 if e.kind() == std::io::ErrorKind::UnexpectedEof {
119 Error::Truncated
120 } else {
121 Error::Io(e)
122 }
123 })?;
124 let name_len = read_u16_le(r)?;
125 let mut name_buf = vec![0u8; name_len as usize];
126 r.read_exact(&mut name_buf)?;
127 let name = String::from_utf8(name_buf).map_err(Error::InvalidUtf8)?;
128
129 match tag[0] {
130 NODE_DIR => {
131 let child_count = read_u32_le(r)?;
132 let n = usize::try_from(child_count).map_err(|_| Error::Truncated)?;
133 let mut children = Vec::with_capacity(n);
134 for _ in 0..child_count {
135 children.push(deserialize_node(r)?);
136 }
137 Ok(Node::Dir { name, children })
138 }
139 NODE_FILE => {
140 let content_len = read_u64_le(r)?;
141 let n = usize::try_from(content_len).map_err(|_| Error::Truncated)?;
142 let mut content = vec![0u8; n];
143 r.read_exact(&mut content)?;
144 Ok(Node::File { name, content })
145 }
146 _ => Err(Error::Truncated),
147 }
148}
149
150pub fn serialize(vfs: &Vfs) -> Result<Vec<u8>, Error> {
155 if vfs.is_host_backed() {
156 return Err(Error::HostBacked);
157 }
158 let mut out = Vec::new();
159 out.write_all(MAGIC)?;
160 out.write_all(&[VERSION])?;
161 let cwd = vfs.cwd().as_bytes();
162 let cwd_len = u32::try_from(cwd.len())
163 .map_err(|_| Error::Io(std::io::Error::other("cwd len overflow")))?;
164 write_u32_le(&mut out, cwd_len)?;
165 out.write_all(cwd)?;
166 serialize_node(&mut out, vfs.root())?;
167 Ok(out)
168}
169
170pub fn deserialize(bytes: &[u8]) -> Result<Vfs, Error> {
175 let mut r = Cursor::new(bytes);
176 let mut magic = [0u8; 4];
177 r.read_exact(&mut magic).map_err(|_| Error::Truncated)?;
178 if &magic != MAGIC {
179 return Err(Error::InvalidMagic);
180 }
181 let mut ver = [0u8; 1];
182 r.read_exact(&mut ver).map_err(|_| Error::Truncated)?;
183 if ver[0] != VERSION {
184 return Err(Error::InvalidVersion);
185 }
186 let cwd_len = read_u32_le(&mut r)?;
187 let cwd_len_usize = usize::try_from(cwd_len).map_err(|_| Error::Truncated)?;
188 let mut cwd_buf = vec![0u8; cwd_len_usize];
189 r.read_exact(&mut cwd_buf)?;
190 let cwd = String::from_utf8(cwd_buf).map_err(Error::InvalidUtf8)?;
191 let root = deserialize_node(&mut r)?;
192 Ok(Vfs::from_parts(root, cwd))
193}
194
195pub fn save_to_file(vfs: &Vfs, path: &std::path::Path) -> std::io::Result<()> {
200 let bytes = serialize(vfs).map_err(std::io::Error::other)?;
201 std::fs::write(path, bytes)
202}
203
204pub fn load_from_file(path: &std::path::Path) -> std::io::Result<Vfs> {
209 let bytes = std::fs::read(path)?;
210 deserialize(&bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
211}
212
/// Unit tests for the `.dev_shell.bin` encoder/decoder.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh VFS survives a serialize/deserialize cycle with its cwd intact.
    #[test]
    fn roundtrip_empty_vfs() {
        let vfs = Vfs::new();
        let bytes = serialize(&vfs).unwrap();
        let vfs2 = deserialize(&bytes).unwrap();
        assert_eq!(vfs.cwd(), vfs2.cwd());
    }

    /// A directory containing a file survives a round trip with content intact.
    #[test]
    fn roundtrip_with_dir_and_file() {
        let mut vfs = Vfs::new();
        vfs.mkdir("/foo").unwrap();
        vfs.write_file("/foo/bar", b"content").unwrap();
        let bytes = serialize(&vfs).unwrap();
        let vfs2 = deserialize(&bytes).unwrap();
        assert_eq!(vfs2.read_file("/foo/bar").unwrap(), b"content");
    }

    /// A wrong signature is rejected as `InvalidMagic`.
    #[test]
    fn invalid_magic() {
        // NOTE(review): lengths are little-endian (`read_u32_le`), so
        // `\x00\x00\x00\x01` would decode as 16777216, not 1. This is harmless
        // here because the magic check fails first.
        let bytes = b"XXXX\x01\x00\x00\x00\x01/\x00\x00\x00\x00";
        assert!(matches!(deserialize(bytes), Err(Error::InvalidMagic)));
    }

    /// Pins the user-visible `Display` text of each variant.
    #[test]
    fn error_display() {
        assert_eq!(Error::InvalidMagic.to_string(), "invalid magic");
        assert_eq!(Error::InvalidVersion.to_string(), "invalid version");
        assert_eq!(Error::Truncated.to_string(), "truncated data");
        assert!(Error::Io(std::io::Error::other("e"))
            .to_string()
            .contains("io error"));
        let utf8_err = Vec::<u8>::from([0xff, 0xfe]);
        let e = String::from_utf8(utf8_err).unwrap_err();
        assert!(Error::InvalidUtf8(e).to_string().contains("utf-8"));
    }

    /// Only the wrapping variants expose an underlying `source()`.
    #[test]
    fn error_source() {
        use std::error::Error as _;
        let utf8_err = Vec::<u8>::from([0xff, 0xfe]);
        let e = Error::InvalidUtf8(String::from_utf8(utf8_err).unwrap_err());
        assert!(e.source().is_some());
        let e = Error::Io(std::io::Error::other("test"));
        assert!(e.source().is_some());
        assert!(Error::InvalidMagic.source().is_none());
    }

    /// An unknown root node tag is reported as `Truncated`.
    #[test]
    fn deserialize_invalid_tag() {
        let mut bytes = serialize(&Vfs::new()).unwrap();
        // magic (4) + version (1) + cwd_len (4) + cwd (1 byte; presumably "/"
        // for `Vfs::new()` — TODO confirm) puts the root tag at offset 10.
        let root_tag_offset = 4 + 1 + 4 + 1;
        bytes[root_tag_offset] = 0xff;
        assert!(matches!(deserialize(&bytes), Err(Error::Truncated)));
    }

    /// A version byte other than `VERSION` is rejected as `InvalidVersion`.
    #[test]
    fn invalid_version() {
        let mut bytes = vec![b'D', b'E', b'V', b'S', 99, 0, 0, 0, 1, b'/'];
        bytes.extend_from_slice(&0u32.to_le_bytes());
        assert!(matches!(deserialize(&bytes), Err(Error::InvalidVersion)));
    }

    /// A missing file surfaces as an `Err` from `std::fs::read`.
    #[test]
    fn load_from_file_nonexistent() {
        let path = std::path::Path::new("/nonexistent_devshell_path_12345");
        let r = load_from_file(path);
        assert!(r.is_err());
    }

    /// `From<io::Error>` produces the `Io` variant.
    #[test]
    fn error_from_io() {
        let e: Error = std::io::Error::other("test").into();
        assert!(matches!(e, Error::Io(_)));
    }

    /// Input that stops after the header must fail to decode.
    #[test]
    fn deserialize_truncated_after_cwd() {
        // NOTE(review): the cwd length bytes `0, 0, 0, 1` are little-endian and
        // encode 16777216, not 1. Decoding therefore fails while reading the cwd
        // itself, not after it as the test name suggests.
        let bytes = vec![b'D', b'E', b'V', b'S', 1, 0, 0, 0, 1, b'/'];
        let r = deserialize(&bytes);
        assert!(r.is_err(), "expected Err (Truncated or Io)");
    }

    /// Input that stops partway through the root node must fail to decode.
    #[test]
    fn deserialize_truncated_inside_node() {
        // Empty cwd, then a root dir tag (0) with an empty name (0, 0); the
        // 4-byte child count is missing.
        let bytes: Vec<u8> = [b'D', b'E', b'V', b'S', 1, 0, 0, 0, 0, 0, 0, 0]
            .into_iter()
            .collect();
        let r = deserialize(&bytes);
        assert!(r.is_err(), "expected Err (Truncated or Io)");
    }

    /// EOF while reading the tag byte maps to `Truncated`.
    #[test]
    fn deserialize_node_tag_unexpected_eof_is_truncated() {
        let mut r = Cursor::new(&[]);
        let e = deserialize_node(&mut r).unwrap_err();
        assert!(matches!(e, Error::Truncated));
    }

    /// A non-EOF reader failure while reading the tag byte stays `Io`.
    #[test]
    fn deserialize_node_tag_other_io_error_is_io() {
        // Reader that fails every call with a non-EOF error.
        struct FailRead;
        impl std::io::Read for FailRead {
            fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
                Err(std::io::Error::other("injected read failure"))
            }
        }
        let e = deserialize_node(&mut FailRead).unwrap_err();
        assert!(matches!(e, Error::Io(_)));
    }
}