Skip to main content

xtask_todo_lib/devshell/
serialization.rs

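//! Legacy `.bin` snapshot format for the devshell VFS: the magic bytes `"DEVS"`, a one-byte
//! format version, a length-prefixed cwd, then the root node tree. Integers are little-endian.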
2
3use std::io::{Cursor, Read, Write};
4
5use super::vfs::{Node, Vfs};
6
/// Signature bytes that open every `.bin` snapshot.
const MAGIC: &[u8; 4] = b"DEVS";
/// The only wire-format version this module writes and accepts.
const VERSION: u8 = 1;

/// Node tag byte marking a directory entry.
const NODE_DIR: u8 = 0x00;
/// Node tag byte marking a file entry.
const NODE_FILE: u8 = 0x01;
12
/// Errors produced while encoding or decoding the `.bin` snapshot format.
#[derive(Debug)]
pub enum Error {
    /// The input does not start with the `DEVS` signature.
    InvalidMagic,
    /// The version byte is not the one this module understands.
    InvalidVersion,
    /// The input ended before a complete value could be read; also reported for an
    /// unknown node tag byte.
    Truncated,
    /// A cwd or node name was not valid UTF-8.
    InvalidUtf8(std::string::FromUtf8Error),
    /// An underlying I/O operation failed, or a length did not fit its wire-format field
    /// while encoding.
    Io(std::io::Error),
    /// Host-backed [`super::vfs::Vfs`] cannot be serialized to the legacy `.bin` format.
    HostBacked,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Wrapping variants append the inner error's message; the rest are fixed text.
        let fixed = match self {
            Self::InvalidUtf8(inner) => return write!(f, "invalid utf-8: {inner}"),
            Self::Io(inner) => return write!(f, "io error: {inner}"),
            Self::InvalidMagic => "invalid magic",
            Self::InvalidVersion => "invalid version",
            Self::Truncated => "truncated data",
            Self::HostBacked => "host-backed workspace cannot be saved as .dev_shell.bin",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped UTF-8 or I/O error; the other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::InvalidUtf8(inner) => Some(inner),
            Self::Io(inner) => Some(inner),
            Self::InvalidMagic | Self::InvalidVersion | Self::Truncated | Self::HostBacked => None,
        }
    }
}

impl From<std::io::Error> for Error {
    /// Wraps any I/O failure as [`Error::Io`] so `?` works on `std::io::Result` values.
    fn from(source: std::io::Error) -> Self {
        Error::Io(source)
    }
}
54
/// Writes `n` to `w` as four little-endian bytes.
fn write_u32_le(w: &mut impl Write, n: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = n.to_le_bytes();
    w.write_all(encoded.as_slice())
}
58
/// Writes `n` to `w` as two little-endian bytes.
fn write_u16_le(w: &mut impl Write, n: u16) -> std::io::Result<()> {
    let encoded: [u8; 2] = n.to_le_bytes();
    w.write_all(encoded.as_slice())
}
62
/// Writes `n` to `w` as eight little-endian bytes.
fn write_u64_le(w: &mut impl Write, n: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = n.to_le_bytes();
    w.write_all(encoded.as_slice())
}
66
/// Reads exactly four bytes from `r` and decodes them as a little-endian `u32`.
///
/// # Errors
/// Propagates the `read_exact` failure, e.g. `UnexpectedEof` when fewer than four bytes remain.
fn read_u32_le(r: &mut impl Read) -> std::io::Result<u32> {
    let mut raw = [0u8; std::mem::size_of::<u32>()];
    r.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
72
/// Reads exactly two bytes from `r` and decodes them as a little-endian `u16`.
///
/// # Errors
/// Propagates the `read_exact` failure, e.g. `UnexpectedEof` when fewer than two bytes remain.
fn read_u16_le(r: &mut impl Read) -> std::io::Result<u16> {
    let mut raw = [0u8; std::mem::size_of::<u16>()];
    r.read_exact(&mut raw).map(|()| u16::from_le_bytes(raw))
}
78
/// Reads exactly eight bytes from `r` and decodes them as a little-endian `u64`.
///
/// # Errors
/// Propagates the `read_exact` failure, e.g. `UnexpectedEof` when fewer than eight bytes remain.
fn read_u64_le(r: &mut impl Read) -> std::io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    r.read_exact(&mut raw).map(|()| u64::from_le_bytes(raw))
}
84
85fn serialize_node(w: &mut impl Write, node: &Node) -> std::io::Result<()> {
86    match node {
87        Node::Dir { name, children } => {
88            w.write_all(&[NODE_DIR])?;
89            let name_bytes = name.as_bytes();
90            let len_u16 = u16::try_from(name_bytes.len())
91                .map_err(|_| std::io::Error::other("name len overflow"))?;
92            write_u16_le(w, len_u16)?;
93            w.write_all(name_bytes)?;
94            let len_u32 = u32::try_from(children.len())
95                .map_err(|_| std::io::Error::other("children len overflow"))?;
96            write_u32_le(w, len_u32)?;
97            for child in children {
98                serialize_node(w, child)?;
99            }
100        }
101        Node::File { name, content } => {
102            w.write_all(&[NODE_FILE])?;
103            let name_bytes = name.as_bytes();
104            let len_u16 = u16::try_from(name_bytes.len())
105                .map_err(|_| std::io::Error::other("name len overflow"))?;
106            write_u16_le(w, len_u16)?;
107            w.write_all(name_bytes)?;
108            write_u64_le(w, content.len() as u64)?;
109            w.write_all(content)?;
110        }
111    }
112    Ok(())
113}
114
115fn deserialize_node(r: &mut impl Read) -> Result<Node, Error> {
116    let mut tag = [0u8; 1];
117    r.read_exact(&mut tag).map_err(|e| {
118        if e.kind() == std::io::ErrorKind::UnexpectedEof {
119            Error::Truncated
120        } else {
121            Error::Io(e)
122        }
123    })?;
124    let name_len = read_u16_le(r)?;
125    let mut name_buf = vec![0u8; name_len as usize];
126    r.read_exact(&mut name_buf)?;
127    let name = String::from_utf8(name_buf).map_err(Error::InvalidUtf8)?;
128
129    match tag[0] {
130        NODE_DIR => {
131            let child_count = read_u32_le(r)?;
132            let n = usize::try_from(child_count).map_err(|_| Error::Truncated)?;
133            let mut children = Vec::with_capacity(n);
134            for _ in 0..child_count {
135                children.push(deserialize_node(r)?);
136            }
137            Ok(Node::Dir { name, children })
138        }
139        NODE_FILE => {
140            let content_len = read_u64_le(r)?;
141            let n = usize::try_from(content_len).map_err(|_| Error::Truncated)?;
142            let mut content = vec![0u8; n];
143            r.read_exact(&mut content)?;
144            Ok(Node::File { name, content })
145        }
146        _ => Err(Error::Truncated),
147    }
148}
149
150/// Serialize VFS to .bin format: DEVS magic + version 1 + cwd + root node tree.
151///
152/// # Errors
153/// Returns `Error::Io` on write failure or if cwd/name/children length overflows the wire format.
154pub fn serialize(vfs: &Vfs) -> Result<Vec<u8>, Error> {
155    if vfs.is_host_backed() {
156        return Err(Error::HostBacked);
157    }
158    let mut out = Vec::new();
159    out.write_all(MAGIC)?;
160    out.write_all(&[VERSION])?;
161    let cwd = vfs.cwd().as_bytes();
162    let cwd_len = u32::try_from(cwd.len())
163        .map_err(|_| Error::Io(std::io::Error::other("cwd len overflow")))?;
164    write_u32_le(&mut out, cwd_len)?;
165    out.write_all(cwd)?;
166    serialize_node(&mut out, vfs.root())?;
167    Ok(out)
168}
169
170/// Deserialize VFS from .bin format.
171///
172/// # Errors
173/// Returns `Error` on invalid magic/version, truncated data, or invalid UTF-8.
174pub fn deserialize(bytes: &[u8]) -> Result<Vfs, Error> {
175    let mut r = Cursor::new(bytes);
176    let mut magic = [0u8; 4];
177    r.read_exact(&mut magic).map_err(|_| Error::Truncated)?;
178    if &magic != MAGIC {
179        return Err(Error::InvalidMagic);
180    }
181    let mut ver = [0u8; 1];
182    r.read_exact(&mut ver).map_err(|_| Error::Truncated)?;
183    if ver[0] != VERSION {
184        return Err(Error::InvalidVersion);
185    }
186    let cwd_len = read_u32_le(&mut r)?;
187    let cwd_len_usize = usize::try_from(cwd_len).map_err(|_| Error::Truncated)?;
188    let mut cwd_buf = vec![0u8; cwd_len_usize];
189    r.read_exact(&mut cwd_buf)?;
190    let cwd = String::from_utf8(cwd_buf).map_err(Error::InvalidUtf8)?;
191    let root = deserialize_node(&mut r)?;
192    Ok(Vfs::from_parts(root, cwd))
193}
194
195/// Save VFS to a .bin file.
196///
197/// # Errors
198/// Returns I/O error on serialize or write failure.
199pub fn save_to_file(vfs: &Vfs, path: &std::path::Path) -> std::io::Result<()> {
200    let bytes = serialize(vfs).map_err(std::io::Error::other)?;
201    std::fs::write(path, bytes)
202}
203
204/// Load VFS from a .bin file.
205///
206/// # Errors
207/// Returns I/O error on read failure or invalid/corrupt .bin data.
208pub fn load_from_file(path: &std::path::Path) -> std::io::Result<Vfs> {
209    let bytes = std::fs::read(path)?;
210    deserialize(&bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-encodes the fixed prefix that `deserialize` reads before the root node: magic,
    /// version byte, little-endian `u32` cwd length, then the cwd bytes.
    ///
    /// Building fixtures through this helper keeps the length field little-endian; writing it
    /// byte-by-byte is easy to get wrong (`0, 0, 0, 1` decodes as 16 MiB, not 1).
    fn header(magic: &[u8; 4], version: u8, cwd: &str) -> Vec<u8> {
        let cwd_len = u32::try_from(cwd.len()).expect("test cwd fits in u32");
        let mut out = Vec::with_capacity(magic.len() + 1 + 4 + cwd.len());
        out.extend_from_slice(magic);
        out.push(version);
        out.extend_from_slice(&cwd_len.to_le_bytes());
        out.extend_from_slice(cwd.as_bytes());
        out
    }

    #[test]
    fn roundtrip_empty_vfs() {
        let vfs = Vfs::new();
        let bytes = serialize(&vfs).unwrap();
        let vfs2 = deserialize(&bytes).unwrap();
        assert_eq!(vfs.cwd(), vfs2.cwd());
    }

    #[test]
    fn roundtrip_with_dir_and_file() {
        let mut vfs = Vfs::new();
        vfs.mkdir("/foo").unwrap();
        vfs.write_file("/foo/bar", b"content").unwrap();
        let bytes = serialize(&vfs).unwrap();
        let vfs2 = deserialize(&bytes).unwrap();
        assert_eq!(vfs2.read_file("/foo/bar").unwrap(), b"content");
    }

    /// Nested directories, an empty file, and non-UTF-8 file content all survive a round trip.
    #[test]
    fn roundtrip_nested_dirs_and_binary_content() {
        let mut vfs = Vfs::new();
        vfs.mkdir("/a").unwrap();
        vfs.mkdir("/a/b").unwrap();
        vfs.write_file("/a/b/empty", b"").unwrap();
        vfs.write_file("/bin", b"\x00\xff\x10").unwrap();
        let vfs2 = deserialize(&serialize(&vfs).unwrap()).unwrap();
        assert_eq!(vfs2.read_file("/a/b/empty").unwrap(), b"");
        assert_eq!(vfs2.read_file("/bin").unwrap(), b"\x00\xff\x10");
    }

    #[test]
    fn invalid_magic() {
        let mut bytes = header(b"XXXX", VERSION, "/");
        bytes.extend_from_slice(&[0, 0, 0, 0]);
        assert!(matches!(deserialize(&bytes), Err(Error::InvalidMagic)));
    }

    #[test]
    fn error_display() {
        assert_eq!(Error::InvalidMagic.to_string(), "invalid magic");
        assert_eq!(Error::InvalidVersion.to_string(), "invalid version");
        assert_eq!(Error::Truncated.to_string(), "truncated data");
        assert_eq!(
            Error::HostBacked.to_string(),
            "host-backed workspace cannot be saved as .dev_shell.bin"
        );
        assert!(Error::Io(std::io::Error::other("e"))
            .to_string()
            .contains("io error"));
        let utf8_err = Vec::<u8>::from([0xff, 0xfe]);
        let e = String::from_utf8(utf8_err).unwrap_err();
        assert!(Error::InvalidUtf8(e).to_string().contains("utf-8"));
    }

    #[test]
    fn error_source() {
        use std::error::Error as _;
        let utf8_err = Vec::<u8>::from([0xff, 0xfe]);
        let e = Error::InvalidUtf8(String::from_utf8(utf8_err).unwrap_err());
        assert!(e.source().is_some());
        let e = Error::Io(std::io::Error::other("test"));
        assert!(e.source().is_some());
        assert!(Error::InvalidMagic.source().is_none());
        assert!(Error::HostBacked.source().is_none());
    }

    #[test]
    fn deserialize_invalid_tag() {
        let vfs = Vfs::new();
        let mut bytes = serialize(&vfs).unwrap();
        // The root node's tag byte sits right after MAGIC + VERSION + cwd_len + cwd.
        let root_tag_offset = MAGIC.len() + 1 + 4 + vfs.cwd().len();
        bytes[root_tag_offset] = 0xff; // invalid node tag
        assert!(matches!(deserialize(&bytes), Err(Error::Truncated)));
    }

    #[test]
    fn invalid_version() {
        let mut bytes = header(MAGIC, 99, "/");
        bytes.extend_from_slice(&0u32.to_le_bytes());
        assert!(matches!(deserialize(&bytes), Err(Error::InvalidVersion)));
    }

    #[test]
    fn load_from_file_nonexistent() {
        let path = std::path::Path::new("/nonexistent_devshell_path_12345");
        assert!(load_from_file(path).is_err());
    }

    #[test]
    fn error_from_io() {
        let e: Error = std::io::Error::other("test").into();
        assert!(matches!(e, Error::Io(_)));
    }

    /// Complete header with cwd = "/" but no root node: the root tag read hits EOF.
    #[test]
    fn deserialize_truncated_after_cwd() {
        let bytes = header(MAGIC, VERSION, "/");
        assert!(matches!(deserialize(&bytes), Err(Error::Truncated)));
    }

    /// Root directory tag and an empty name decode, then the `u32` child count is missing.
    #[test]
    fn deserialize_truncated_inside_node() {
        let mut bytes = header(MAGIC, VERSION, "");
        bytes.push(NODE_DIR);
        bytes.extend_from_slice(&0u16.to_le_bytes());
        assert!(deserialize(&bytes).is_err(), "expected Err (Truncated or Io)");
    }

    /// EOF on the first byte of a node maps to `Error::Truncated`.
    #[test]
    fn deserialize_node_tag_unexpected_eof_is_truncated() {
        let mut empty: &[u8] = &[];
        let e = deserialize_node(&mut empty).unwrap_err();
        assert!(matches!(e, Error::Truncated));
    }

    /// Non-EOF I/O errors from reading the node tag map to `Error::Io`.
    #[test]
    fn deserialize_node_tag_other_io_error_is_io() {
        struct FailRead;
        impl std::io::Read for FailRead {
            fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
                Err(std::io::Error::other("injected read failure"))
            }
        }
        let e = deserialize_node(&mut FailRead).unwrap_err();
        assert!(matches!(e, Error::Io(_)));
    }
}