1#![doc = include_str!("../README.md")]
14
15use serde::{Serialize, de::DeserializeOwned};
16
17#[derive(Debug, thiserror::Error)]
24pub enum CodecError {
25 #[error("payload empty")]
27 Empty,
28 #[error("version mismatch: expected {expected}, got {actual}")]
32 Version { expected: u8, actual: u8 },
33 #[error("encode failed: {0}")]
35 Encode(#[source] postcard::Error),
36 #[error("decode failed: {0}")]
38 Decode(#[source] postcard::Error),
39 #[error("trailing bytes: {extra} unconsumed after a valid body")]
43 TrailingBytes { extra: usize },
44}
45
46pub fn encode<T: Serialize>(version: u8, value: &T) -> Result<Vec<u8>, CodecError> {
53 let body = postcard::to_stdvec(value).map_err(CodecError::Encode)?;
54 let mut out = Vec::with_capacity(1 + body.len());
55 out.push(version);
56 out.extend_from_slice(&body);
57 Ok(out)
58}
59
60pub fn decode<T: DeserializeOwned>(expected_version: u8, bytes: &[u8]) -> Result<T, CodecError> {
69 let (first, rest) = bytes.split_first().ok_or(CodecError::Empty)?;
70 if *first != expected_version {
71 return Err(CodecError::Version {
72 expected: expected_version,
73 actual: *first,
74 });
75 }
76 let (value, remainder) = postcard::take_from_bytes(rest).map_err(CodecError::Decode)?;
77 if !remainder.is_empty() {
78 return Err(CodecError::TrailingBytes {
79 extra: remainder.len(),
80 });
81 }
82 Ok(value)
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use serde::{Deserialize, Serialize};

    /// Small payload with one integer and one string field, used by every test below.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Sample {
        idx: u64,
        name: String,
    }

    /// Encoding under version 1 puts that tag in byte 0, and decoding yields an equal value.
    #[test]
    fn encode_decode_roundtrip() {
        let original = Sample {
            idx: 42,
            name: String::from("tsoracle"),
        };

        let bytes = encode(1, &original).expect("encode");
        assert_eq!(bytes[0], 1);

        let decoded = decode::<Sample>(1, &bytes).expect("decode");
        assert_eq!(original, decoded);
    }

    /// A frame written as version 2 is refused when version 1 is requested.
    #[test]
    fn decode_rejects_wrong_version() {
        let payload = Sample {
            idx: 1,
            name: String::from("x"),
        };
        let bytes = encode(2, &payload).expect("encode");

        let err = decode::<Sample>(1, &bytes).expect_err("must reject");
        let is_expected_mismatch = matches!(
            err,
            CodecError::Version {
                expected: 1,
                actual: 2
            }
        );
        assert!(is_expected_mismatch);
    }

    /// An empty slice is reported as `Empty`.
    #[test]
    fn decode_rejects_empty() {
        let no_bytes: &[u8] = &[];
        let err = decode::<Sample>(1, no_bytes).expect_err("must reject");
        assert!(matches!(err, CodecError::Empty));
    }

    /// Cutting a valid frame in half surfaces a postcard decode error.
    #[test]
    fn decode_rejects_truncated_input() {
        let original = Sample {
            idx: u64::MAX,
            name: String::from("hello-world-storage-roundtrip"),
        };
        let bytes = encode(1, &original).expect("encode");
        assert!(bytes.len() >= 16, "payload should be non-trivial");

        // Keep only the first half of the frame.
        let (truncated, _dropped) = bytes.split_at(bytes.len() / 2);
        let outcome = decode::<Sample>(1, truncated);
        assert!(matches!(outcome, Err(CodecError::Decode(_))));
    }

    /// Extra bytes after a valid body are counted and rejected.
    #[test]
    fn decode_rejects_trailing_bytes() {
        let original = Sample {
            idx: 7,
            name: String::from("trailing"),
        };
        let junk = [0xAB_u8, 0xCD, 0xEF];

        let mut bytes = encode(1, &original).expect("encode");
        bytes.extend_from_slice(&junk);

        let outcome = decode::<Sample>(1, &bytes);
        assert!(matches!(
            outcome,
            Err(CodecError::TrailingBytes { extra: 3 })
        ));
    }

    proptest! {
        // Round-trip holds for arbitrary version tags and field values.
        #[test]
        fn encode_decode_roundtrip_any(
            version in any::<u8>(),
            idx in any::<u64>(),
            name in any::<String>(),
        ) {
            let sample = Sample { idx, name };
            let bytes = encode(version, &sample).unwrap();
            prop_assert_eq!(bytes[0], version);

            let restored: Sample = decode(version, &bytes).unwrap();
            prop_assert_eq!(sample, restored);
        }

        // Any pair of distinct versions yields a `Version` error that reports
        // both the requested and the encountered tag.
        #[test]
        fn decode_rejects_any_version_mismatch(
            encoded in any::<u8>(),
            expected in any::<u8>(),
            idx in any::<u64>(),
            name in any::<String>(),
        ) {
            prop_assume!(encoded != expected);

            let sample = Sample { idx, name };
            let bytes = encode(encoded, &sample).unwrap();

            let outcome = decode::<Sample>(expected, &bytes);
            if let Err(CodecError::Version { expected: want, actual: got }) = outcome {
                prop_assert_eq!(want, expected);
                prop_assert_eq!(got, encoded);
            } else {
                prop_assert!(false, "expected Version mismatch; got {outcome:?}");
            }
        }
    }
}