1use std::default::Default;
18use std::error::Error as ErrorTrait;
19use std::fmt::{Display, Formatter, Result as FmtResult};
20use std::io::{Error as IoError, Read, Result as IoResult};
21use std::marker::PhantomData;
22
23use bytes::{Buf, BufMut, BytesMut};
24use serde::{Deserialize, Serialize};
25use serde_cbor::de::{Deserializer, IoRead};
26use serde_cbor::error::Error as CborError;
27use serde_cbor::ser::{IoWrite, Serializer};
28use tokio_util::codec::{Decoder as IoDecoder, Encoder as IoEncoder};
29
30#[derive(Debug)]
32#[non_exhaustive]
33pub enum Error {
34 Io(IoError),
35 Cbor(CborError),
36}
37
38impl From<IoError> for Error {
39 fn from(error: IoError) -> Self {
40 Error::Io(error)
41 }
42}
43
44impl From<CborError> for Error {
45 fn from(error: CborError) -> Self {
46 Error::Cbor(error)
47 }
48}
49
50impl Display for Error {
51 fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
52 match self {
53 Error::Io(e) => e.fmt(fmt),
54 Error::Cbor(e) => e.fmt(fmt),
55 }
56 }
57}
58
59impl ErrorTrait for Error {
60 fn cause(&self) -> Option<&dyn ErrorTrait> {
61 match self {
62 Error::Io(e) => Some(e),
63 Error::Cbor(e) => Some(e),
64 }
65 }
66}
67
/// A [`Read`] adapter that forwards to an inner reader and adds the number of
/// bytes each successful `read` returned to an externally owned counter.
struct Counted<'a, R: 'a> {
    /// The reader being wrapped.
    r: &'a mut R,
    /// Running byte total, increased after every successful `read`.
    pos: &'a mut usize,
}

impl<'a, R: Read> Read for Counted<'a, R> {
    /// Reads from the inner reader. On success the byte count is added to
    /// `pos`; errors are returned unchanged and leave `pos` untouched.
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let n = self.r.read(buf)?;
        *self.pos += n;
        Ok(n)
    }
}
88
/// Decodes a stream of CBOR-encoded `Item`s from a [`BytesMut`] buffer.
///
/// The decoder holds no state besides its type parameter.
pub struct Decoder<Item> {
    // `fn() -> Item`: the decoder produces `Item`s but never stores one, so
    // `Send`/`Sync` do not depend on `Item`.
    _data: PhantomData<fn() -> Item>,
}

// `Clone` and `Debug` are written by hand because `#[derive]` would require
// `Item: Clone` / `Item: Debug` even though no `Item` is stored. The `Debug`
// output is identical to what the derive produced.
impl<Item> Clone for Decoder<Item> {
    fn clone(&self) -> Self {
        Decoder { _data: PhantomData }
    }
}

impl<Item> std::fmt::Debug for Decoder<Item> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.debug_struct("Decoder")
            .field("_data", &self._data)
            .finish()
    }
}
97
98impl<'de, Item: Deserialize<'de>> Decoder<Item> {
99 pub fn new() -> Self {
101 Self { _data: PhantomData }
102 }
103}
104
105impl<'de, Item: Deserialize<'de>> Default for Decoder<Item> {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl<'de, Item: Deserialize<'de>> IoDecoder for Decoder<Item> {
112 type Item = Item;
113 type Error = Error;
114 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Item>, Error> {
115 let mut pos = 0;
116 let result = {
117 let mut slice: &[u8] = src;
118 let reader = Counted {
119 r: &mut slice,
120 pos: &mut pos,
121 };
122 let reader = IoRead::new(reader);
123 let mut deserializer = Deserializer::new(reader);
127 Item::deserialize(&mut deserializer)
128 };
129 match result {
130 Ok(item) => {
132 src.advance(pos);
133 Ok(Some(item))
134 }
135 Err(ref error) if error.is_eof() => Ok(None),
137 Err(e) => Err(e.into()),
139 }
140 }
141}
142
/// Controls when the encoder writes the CBOR self-describe tag in front of an
/// encoded item.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SdMode {
    /// Write the self-describe tag before every encoded item.
    Always,
    /// Write the self-describe tag before the first encoded item only, then
    /// behave like [`SdMode::Never`].
    Once,
    /// Never write the self-describe tag.
    Never,
}

/// Encodes values of type `Item` as CBOR into a [`BytesMut`] buffer.
///
/// Create one with `Encoder::new` and configure it with the `sd` and `packed`
/// builder methods.
pub struct Encoder<Item> {
    // `fn(Item)`: the encoder consumes `Item`s but never stores one, so
    // `Send`/`Sync` do not depend on `Item`.
    _data: PhantomData<fn(Item)>,
    /// When to write the self-describe tag.
    sd: SdMode,
    /// Whether serde_cbor's packed format is used.
    packed: bool,
}

// `Clone` and `Debug` are written by hand because `#[derive]` would require
// `Item: Clone` / `Item: Debug` even though no `Item` is stored. The `Debug`
// output is identical to what the derive produced.
impl<Item> Clone for Encoder<Item> {
    fn clone(&self) -> Self {
        Encoder {
            _data: PhantomData,
            sd: self.sd,
            packed: self.packed,
        }
    }
}

impl<Item> std::fmt::Debug for Encoder<Item> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.debug_struct("Encoder")
            .field("_data", &self._data)
            .field("sd", &self.sd)
            .field("packed", &self.packed)
            .finish()
    }
}
168
169impl<Item: Serialize> Encoder<Item> {
170 pub fn new() -> Self {
175 Self {
176 _data: PhantomData,
177 sd: SdMode::Never,
178 packed: false,
179 }
180 }
181 pub fn sd(self, sd: SdMode) -> Self {
183 Self { sd, ..self }
184 }
185 pub fn packed(self, packed: bool) -> Self {
191 Self { packed, ..self }
192 }
193}
194
195impl<Item: Serialize> Default for Encoder<Item> {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl<Item: Serialize> IoEncoder<Item> for Encoder<Item> {
202 type Error = Error;
203 fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Error> {
204 let mut serializer = if self.packed {
205 Serializer::new(IoWrite::new(dst.writer())).packed_format()
206 } else {
207 Serializer::new(IoWrite::new(dst.writer()))
208 };
209 if self.sd != SdMode::Never {
210 serializer.self_describe()?;
211 }
212 if self.sd == SdMode::Once {
213 self.sd = SdMode::Never;
214 }
215 item.serialize(&mut serializer).map_err(Into::into)
216 }
217}
218
219#[derive(Clone, Debug)]
223pub struct Codec<Dec, Enc> {
224 dec: Decoder<Dec>,
225 enc: Encoder<Enc>,
226}
227
228impl<'de, Dec: Deserialize<'de>, Enc: Serialize> Codec<Dec, Enc> {
229 pub fn new() -> Self {
231 Self {
232 dec: Decoder::new(),
233 enc: Encoder::new(),
234 }
235 }
236 pub fn sd(self, sd: SdMode) -> Self {
238 Self {
239 dec: self.dec,
240 enc: Encoder { sd, ..self.enc },
241 }
242 }
243 pub fn packed(self, packed: bool) -> Self {
249 Self {
250 dec: self.dec,
251 enc: Encoder { packed, ..self.enc },
252 }
253 }
254}
255
256impl<'de, Dec: Deserialize<'de>, Enc: Serialize> Default for Codec<Dec, Enc> {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262impl<'de, Dec: Deserialize<'de>, Enc: Serialize> IoDecoder for Codec<Dec, Enc> {
263 type Item = Dec;
264 type Error = Error;
265 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Dec>, Error> {
266 self.dec.decode(src)
267 }
268}
269
270impl<'de, Dec: Deserialize<'de>, Enc: Serialize> IoEncoder<Enc> for Codec<Dec, Enc> {
271 type Error = Error;
272 fn encode(&mut self, item: Enc, dst: &mut BytesMut) -> Result<(), Error> {
273 self.enc.encode(item, dst)
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;

    /// Payload type used by every test: a string-keyed map of integers.
    type TestData = HashMap<String, usize>;

    /// Builds a small, fixed `TestData` value with two entries.
    fn test_data() -> TestData {
        let mut data = HashMap::new();
        data.insert("hello".to_owned(), 42usize);
        data.insert("world".to_owned(), 0usize);
        data
    }

    /// Drives a decoder through complete, partial and invalid input.
    ///
    /// Checks that:
    /// * complete frames are decoded and consumed from the buffer;
    /// * a partial frame yields `Ok(None)` and is left in the buffer;
    /// * an invalid frame yields an error and is left in the buffer.
    fn decode<Dec: IoDecoder<Item = TestData, Error = Error>>(dec: Dec) {
        let mut decoder = dec;
        let data = test_data();
        let encoded = serde_cbor::to_vec(&data).unwrap();
        let mut all = BytesMut::with_capacity(128);
        all.extend(&encoded);
        // Two complete frames followed by only the first byte of a third.
        all.extend(&encoded);
        all.extend(&encoded[..1]);
        let decoded = decoder.decode(&mut all).unwrap().unwrap();
        // The first frame must round-trip to the original value.
        assert_eq!(data, decoded);
        let decoded = decoder.decode(&mut all).unwrap().unwrap();
        assert_eq!(data, decoded);
        assert_eq!(1, all.len());
        // A single byte of the third frame is not enough to decode anything...
        assert!(decoder.decode(&mut all).unwrap().is_none());
        // ...and the partial frame must not be consumed.
        assert_eq!(1, all.len());
        all.extend(&encoded[1..]);
        // Once the rest of the frame arrives it decodes and the buffer drains.
        let decoded = decoder.decode(&mut all).unwrap().unwrap();
        assert_eq!(data, decoded);
        assert!(all.is_empty());
        all.extend(&[0, 1, 2, 3, 4]);
        // These bytes do not encode a map, so decoding must fail with an error
        // (not `Ok(None)`) and leave the bytes in place.
        decoder.decode(&mut all).unwrap_err();
        assert_eq!(5, all.len());
    }

    #[test]
    fn decode_only() {
        let decoder = Decoder::new();
        decode(decoder);
    }

    #[test]
    fn decode_codec() {
        let decoder: Codec<_, ()> = Codec::new();
        decode(decoder);
    }

    /// Encodes the test value three times with an encoder configured for
    /// `SdMode::Once` (by every caller below) and checks the frame sizes.
    fn encode<Enc: IoEncoder<TestData, Error = Error>>(enc: Enc) {
        let mut encoder = enc;
        let data = test_data();
        let mut buffer = BytesMut::with_capacity(0);
        encoder.encode(data.clone(), &mut buffer).unwrap();
        let pos1 = buffer.len();
        // The first frame alone must decode back to the original value.
        let decoded = serde_cbor::from_slice::<TestData>(&buffer).unwrap();
        assert_eq!(data, decoded);
        encoder.encode(data.clone(), &mut buffer).unwrap();
        let pos2 = buffer.len();
        // The second frame was appended...
        assert!(pos2 > pos1);
        // ...and is shorter than the first, because only the first frame carries
        // the self-describe tag.
        assert!(pos1 * 2 > pos2);
        // The second frame decodes on its own.
        let decoded = serde_cbor::from_slice::<TestData>(&buffer[pos1..]).unwrap();
        assert_eq!(data, decoded);
        encoder.encode(data, &mut buffer).unwrap();
        // Frames after the first are all the same size (no tag on any of them).
        let pos3 = buffer.len();
        assert_eq!(pos2 - pos1, pos3 - pos2);
    }

    #[test]
    fn encode_only() {
        let encoder = Encoder::new().sd(SdMode::Once);
        encode(encoder);
    }

    #[test]
    fn encode_packed() {
        let encoder = Encoder::new().packed(true).sd(SdMode::Once);
        encode(encoder);
    }

    #[test]
    fn encode_codec() {
        let encoder: Codec<(), _> = Codec::new().sd(SdMode::Once);
        encode(encoder);
    }

    // Compile-time check: moving the codec into a spawned thread requires
    // `Codec<(), ()>: Send`.
    #[test]
    fn is_send() {
        let codec: Codec<(), ()> = Codec::new();
        std::thread::spawn(move || {
            let _c = codec;
        });
    }

    // Compile-time check: sending an `Arc<T>` to another thread requires
    // `T: Send + Sync`.
    #[test]
    fn is_sync() {
        let codec: Arc<Codec<(), ()>> = Arc::new(Codec::new());
        std::thread::spawn(move || {
            let _c = codec;
        });
    }
}