Skip to main content

tempest_core/encoding/encoding_lexical.rs

1use std::string::FromUtf8Error;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4
/// Order-preserving ("lexical") encoders that append values to a byte buffer.
///
/// The encodings are chosen so that comparing encoded bytes lexicographically
/// orders them the same way as the original values (see the `BytesMut`
/// implementation and the ordering tests in this module). Values are decoded
/// again with `BufGetLexicalExt`.
pub trait BufPutLexicalExt {
    /// Appends `i` as 8 big-endian bytes with the sign bit flipped, so that
    /// negative numbers sort before non-negative ones.
    fn put_i64_lexical(&mut self, i: i64);
    /// Appends `b` as a single byte: `0` for `false`, `1` for `true`.
    fn put_bool_lexical(&mut self, b: bool);
    /// Appends the UTF-8 bytes of `s` with every `0x00` escaped as
    /// `0x00 0xFF`, followed by a `0x00 0x00` terminator.
    fn put_str_lexical(&mut self, s: &str);
}
10
11impl BufPutLexicalExt for BytesMut {
12    fn put_i64_lexical(&mut self, i: i64) {
13        self.put_u64((i ^ i64::MIN) as u64);
14    }
15
16    fn put_bool_lexical(&mut self, b: bool) {
17        self.put_u8(b as u8);
18    }
19
20    fn put_str_lexical(&mut self, s: &str) {
21        for &c in s.as_bytes() {
22            match c {
23                0x00 => self.put_slice(&[0x00, 0xFF]),
24                _ => self.put_u8(c),
25            }
26        }
27        self.put_slice(&[0x00, 0x00]);
28    }
29}
30
/// Errors returned by the `BufGetLexicalExt` decoders.
///
/// NOTE(review): `Display`, `Error` and `From` are not std derives and are not
/// imported in this file; presumably they come from `derive_more` (e.g. via a
/// crate-level `#[macro_use]`) — confirm.
#[derive(Debug, Display, Error, From)]
pub enum LexicalDecodeError {
    /// The buffer ran out before a complete value was read: fewer than 8 bytes
    /// for an `i64`, no byte for a `bool`, or no `0x00 0x00` string terminator.
    UnexpectedEof,
    /// A decoded string's bytes were not valid UTF-8.
    FromUtf8Error(FromUtf8Error),
}
36
/// Decoders for the order-preserving encodings written by `BufPutLexicalExt`.
///
/// The `Bytes` implementation in this module consumes each decoded value from
/// the front of the buffer, so successive calls read successive fields.
pub trait BufGetLexicalExt {
    /// Reads an `i64` written by `put_i64_lexical`.
    ///
    /// # Errors
    /// `LexicalDecodeError::UnexpectedEof` if fewer than 8 bytes remain.
    fn get_i64_lexical(&mut self) -> Result<i64, LexicalDecodeError>;
    /// Reads a `bool` written by `put_bool_lexical`; any non-zero byte is `true`.
    ///
    /// # Errors
    /// `LexicalDecodeError::UnexpectedEof` if the buffer is empty.
    fn get_bool_lexical(&mut self) -> Result<bool, LexicalDecodeError>;
    /// Reads a string written by `put_str_lexical`, up to its `0x00 0x00`
    /// terminator.
    ///
    /// # Errors
    /// `LexicalDecodeError::UnexpectedEof` if no terminator is found, or
    /// `LexicalDecodeError::FromUtf8Error` if the unescaped bytes are not UTF-8.
    fn get_str_lexical(&mut self) -> Result<String, LexicalDecodeError>;
}
42
43impl BufGetLexicalExt for Bytes {
44    fn get_i64_lexical(&mut self) -> Result<i64, LexicalDecodeError> {
45        if self.len() < 8 {
46            return Err(LexicalDecodeError::UnexpectedEof);
47        }
48        Ok((self.get_u64() as i64) ^ i64::MIN)
49    }
50
51    fn get_bool_lexical(&mut self) -> Result<bool, LexicalDecodeError> {
52        if self.is_empty() {
53            return Err(LexicalDecodeError::UnexpectedEof);
54        }
55        Ok(self.get_u8() != 0)
56    }
57
58    fn get_str_lexical(&mut self) -> Result<String, LexicalDecodeError> {
59        let mut pos = 0;
60        let mut result = Vec::with_capacity(64);
61        while pos + 1 < self.len() {
62            match (self[pos], self[pos + 1]) {
63                (0x00, 0x00) => {
64                    self.advance(pos + 2);
65                    return String::from_utf8(result).map_err(LexicalDecodeError::FromUtf8Error);
66                }
67                (0x00, 0xFF) => {
68                    result.push(0);
69                    pos += 2;
70                }
71                (c, _) => {
72                    // NB: silently recover orphan null bytes
73                    result.push(c);
74                    pos += 1;
75                }
76            }
77        }
78        Err(LexicalDecodeError::UnexpectedEof)
79    }
80}
81
#[cfg(test)]
mod tests {
    use bytes::{Bytes, BytesMut};

    use super::*;

    // -- helpers --

    fn encode_i64(i: i64) -> Bytes {
        let mut buf = BytesMut::new();
        buf.put_i64_lexical(i);
        buf.freeze()
    }

    fn encode_bool(b: bool) -> Bytes {
        let mut buf = BytesMut::new();
        buf.put_bool_lexical(b);
        buf.freeze()
    }

    fn encode_str(s: &str) -> Bytes {
        let mut buf = BytesMut::new();
        buf.put_str_lexical(s);
        buf.freeze()
    }

    // -- i64 --

    #[test]
    fn test_i64_roundtrip() {
        for val in [0i64, 1, -1, i64::MIN, i64::MAX, -1000, 1000, 42] {
            let mut bytes = encode_i64(val);
            assert_eq!(
                bytes.get_i64_lexical().unwrap(),
                val,
                "roundtrip failed for {}",
                val
            );
        }
    }

    #[test]
    fn test_i64_known_encodings() {
        // flipping the sign bit maps i64::MIN..=i64::MAX onto 0x00..00..=0xFF..FF
        assert_eq!(encode_i64(i64::MIN), Bytes::from_static(&[0x00; 8]));
        assert_eq!(encode_i64(0), Bytes::from_static(&[0x80, 0, 0, 0, 0, 0, 0, 0]));
        assert_eq!(
            encode_i64(-1),
            Bytes::from_static(&[0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
        );
        assert_eq!(encode_i64(i64::MAX), Bytes::from_static(&[0xFF; 8]));
    }

    #[test]
    fn test_i64_ordering() {
        let cases = [i64::MIN, -1000, -1, 0, 1, 1000, i64::MAX];
        for pair in cases.windows(2) {
            let (a, b) = (pair[0], pair[1]);
            assert!(
                encode_i64(a) < encode_i64(b),
                "{} should encode less than {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_i64_eof() {
        let mut bytes = Bytes::from_static(&[0x00, 0x00]); // only 2 bytes, need 8
        assert!(matches!(
            bytes.get_i64_lexical(),
            Err(LexicalDecodeError::UnexpectedEof)
        ));
        // a failed read must not consume the partial input
        assert_eq!(bytes.len(), 2);
    }

    // -- bool --

    #[test]
    fn test_bool_roundtrip() {
        for val in [true, false] {
            let mut bytes = encode_bool(val);
            assert_eq!(bytes.get_bool_lexical().unwrap(), val);
        }
    }

    #[test]
    fn test_bool_ordering() {
        assert!(encode_bool(false) < encode_bool(true));
    }

    #[test]
    fn test_bool_nonzero_byte_decodes_true() {
        // the decoder treats any non-zero byte as `true`, not just 0x01
        let mut bytes = Bytes::from_static(&[0x02]);
        assert!(bytes.get_bool_lexical().unwrap());
    }

    #[test]
    fn test_bool_eof() {
        let mut bytes = Bytes::new();
        assert!(matches!(
            bytes.get_bool_lexical(),
            Err(LexicalDecodeError::UnexpectedEof)
        ));
    }

    // -- string --

    #[test]
    fn test_str_roundtrip() {
        let cases = [
            "",
            "hello",
            "hello world",
            // 2-, 3- and 4-byte UTF-8 sequences (e-acute, CJK, crab emoji);
            // escapes keep the literal immune to source-encoding mangling
            "unicode: caf\u{e9} \u{65e5}\u{672c}\u{8a9e} \u{1f980}",
        ];
        for val in cases {
            let mut bytes = encode_str(val);
            assert_eq!(
                bytes.get_str_lexical().unwrap(),
                val,
                "roundtrip failed for {:?}",
                val
            );
        }
    }

    #[test]
    fn test_str_with_null_bytes() {
        let s = "hel\x00lo";
        let mut bytes = encode_str(s);
        assert_eq!(bytes.get_str_lexical().unwrap(), s);
    }

    #[test]
    fn test_str_all_null_bytes() {
        let s = "\x00\x00\x00";
        let mut bytes = encode_str(s);
        assert_eq!(bytes.get_str_lexical().unwrap(), s);
    }

    #[test]
    fn test_str_ordering() {
        // UTF-8 byte order agrees with code point order, including multi-byte chars
        let cases = [
            "", "a", "aa", "ab", "b", "z", "\u{e9}", "\u{65e5}", "\u{1f980}",
        ];
        for pair in cases.windows(2) {
            let (a, b) = (pair[0], pair[1]);
            assert!(
                encode_str(a) < encode_str(b),
                "{:?} should encode less than {:?}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_str_null_byte_ordering() {
        // "\x00" < "a" lexicographically in the original string,
        // and must remain so after encoding
        assert!(encode_str("\x00") < encode_str("a"));
        assert!(encode_str("a\x00b") < encode_str("a\x01b"));
    }

    #[test]
    fn test_str_orphan_null_is_recovered() {
        // 0x00 followed by neither 0x00 nor 0xFF never comes out of the encoder;
        // the decoder keeps it as a literal NUL instead of failing
        let mut bytes = Bytes::from_static(&[b'a', 0x00, b'b', 0x00, 0x00]);
        assert_eq!(bytes.get_str_lexical().unwrap(), "a\x00b");
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_str_invalid_utf8() {
        // 0xC3 opens a two-byte UTF-8 sequence that is never completed
        let mut bytes = Bytes::from_static(&[0xC3, 0x00, 0x00]);
        assert!(matches!(
            bytes.get_str_lexical(),
            Err(LexicalDecodeError::FromUtf8Error(_))
        ));
    }

    #[test]
    fn test_str_eof_no_terminator() {
        // raw bytes with no \x00\x00 terminator
        let mut bytes = Bytes::from_static(b"hello");
        assert!(matches!(
            bytes.get_str_lexical(),
            Err(LexicalDecodeError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_str_eof_truncated_terminator() {
        // only the first byte of the terminator is present
        let mut bytes = Bytes::from_static(b"abc\x00");
        assert!(matches!(
            bytes.get_str_lexical(),
            Err(LexicalDecodeError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_str_eof_does_not_consume() {
        // an unterminated string (here containing an escaped NUL) leaves the
        // cursor where it was
        let mut bytes = Bytes::from_static(b"hel\x00\xFFlo");
        assert!(matches!(
            bytes.get_str_lexical(),
            Err(LexicalDecodeError::UnexpectedEof)
        ));
        assert_eq!(bytes.len(), 7);
    }

    #[test]
    fn test_str_advances_cursor_correctly() {
        // two strings back to back - cursor must land exactly after first terminator
        let mut buf = BytesMut::new();
        buf.put_str_lexical("foo");
        buf.put_str_lexical("bar");
        let mut bytes = buf.freeze();

        assert_eq!(bytes.get_str_lexical().unwrap(), "foo");
        assert_eq!(bytes.get_str_lexical().unwrap(), "bar");
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_mixed_sequence() {
        // encode a mix of types and decode them back in order
        let mut buf = BytesMut::new();
        buf.put_i64_lexical(-42);
        buf.put_bool_lexical(true);
        buf.put_str_lexical("tempest");
        buf.put_i64_lexical(i64::MAX);

        let mut bytes = buf.freeze();
        assert_eq!(bytes.get_i64_lexical().unwrap(), -42);
        assert!(bytes.get_bool_lexical().unwrap());
        assert_eq!(bytes.get_str_lexical().unwrap(), "tempest");
        assert_eq!(bytes.get_i64_lexical().unwrap(), i64::MAX);
        assert!(bytes.is_empty());
    }
}