Skip to main content

rustolio_utils/bytes/encoding/
mod.rs

1//
2// SPDX-License-Identifier: MPL-2.0
3//
4// Copyright (c) 2026 Tobias Binnewies. All rights reserved.
5//
6// This Source Code Form is subject to the terms of the Mozilla Public
7// License, v. 2.0. If a copy of the MPL was not distributed with this
8// file, You can obtain one at http://mozilla.org/MPL/2.0/.
9//
10
11mod decode;
12mod encode;
13
14use std::{any::type_name, mem};
15
16use crate::{
17    bytes::{Bytes, BytesMut},
18    prelude::*,
19};
20
21use bytes::{Buf, BufMut};
22pub use decode::{AsyncDecoder, Decode, Decoder};
23pub use encode::{AsyncEncoder, Encode, Encoder};
24pub use rustolio_utils_macros::{Decode, Encode, async_impl};
25
26pub fn encode_to_bytes(v: &impl Encode) -> crate::Result<Bytes> {
27    struct BytesWriter(BytesMut);
28
29    impl Encoder for BytesWriter {
30        fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
31            self.0.put_u8(byte);
32            Ok(())
33        }
34        fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
35            self.0.extend_from_slice(bytes);
36            Ok(())
37        }
38    }
39
40    let size = v.encode_size();
41    let mut writer = BytesWriter(BytesMut::with_capacity(size));
42    v.encode(&mut writer)?;
43    let bytes = writer.0.freeze();
44
45    #[cfg(test)]
46    assert_eq!(bytes.len(), size);
47
48    Ok(bytes)
49}
50
51pub fn decode_from_bytes<T: Decode>(bytes: Bytes) -> crate::Result<T> {
52    struct BytesReader(Bytes);
53
54    impl Decoder for BytesReader {
55        fn read_u8(&mut self) -> crate::Result<u8> {
56            Ok(self.0.get_u8())
57        }
58        fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
59            self.0.try_copy_to_slice(buf).context("UnexpectedEof")
60        }
61    }
62
63    let mut reader = BytesReader(bytes);
64    let v: T = T::decode(&mut reader)?;
65    Ok(v)
66}
67
68pub fn encode_to_std<W>(v: &impl Encode, writer: &mut W) -> crate::Result<()>
69where
70    W: std::io::Write + Unpin,
71{
72    struct StdWriter<W>(W);
73
74    impl<W> Encoder for StdWriter<W>
75    where
76        W: std::io::Write + Unpin,
77    {
78        fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
79            std::io::Write::write_all(&mut self.0, &[byte]).context("Failed to write u8")
80        }
81        fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
82            std::io::Write::write_all(&mut self.0, bytes).context("Failed to write all")
83        }
84    }
85
86    let mut writer = StdWriter(writer);
87    v.encode(&mut writer)?;
88    Ok(())
89}
90
91pub fn decode_from_std<R, T: Decode>(reader: &mut R) -> crate::Result<T>
92where
93    R: std::io::Read + Unpin,
94{
95    struct StdReader<R>(R);
96
97    impl<R> Decoder for StdReader<R>
98    where
99        R: std::io::Read + Unpin,
100    {
101        fn read_u8(&mut self) -> crate::Result<u8> {
102            let mut buf = [0];
103            std::io::Read::read_exact(&mut self.0, &mut buf).context("Failed to read u8")?;
104            Ok(buf[0])
105        }
106        fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
107            std::io::Read::read_exact(&mut self.0, buf).context("Failed to read exact")
108        }
109    }
110
111    let mut reader = StdReader(reader);
112    let v: T = T::decode(&mut reader)?;
113    Ok(v)
114}
115
116#[cfg(not(target_arch = "wasm32"))]
117pub async fn encode_to_tokio<W>(v: &impl Encode, writer: &mut W) -> crate::Result<()>
118where
119    W: tokio::io::AsyncWriteExt + Unpin,
120{
121    struct TokioWriter<W>(W);
122
123    impl<W> AsyncEncoder for TokioWriter<W>
124    where
125        W: tokio::io::AsyncWriteExt + Unpin,
126    {
127        async fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
128            tokio::io::AsyncWriteExt::write_u8(&mut self.0, byte)
129                .await
130                .context("Failed to write u8")
131        }
132        async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
133            tokio::io::AsyncWriteExt::write_all(&mut self.0, bytes)
134                .await
135                .context("Failed to write all")
136        }
137    }
138
139    let mut writer = TokioWriter(writer);
140    v.encode_async(&mut writer).await?;
141    Ok(())
142}
143
144#[cfg(not(target_arch = "wasm32"))]
145pub async fn decode_from_tokio<R, T: Decode>(reader: &mut R) -> crate::Result<T>
146where
147    R: tokio::io::AsyncReadExt + Unpin,
148{
149    struct TokioReader<R>(R);
150
151    impl<R> AsyncDecoder for TokioReader<R>
152    where
153        R: tokio::io::AsyncReadExt + Unpin,
154    {
155        async fn read_u8(&mut self) -> crate::Result<u8> {
156            tokio::io::AsyncReadExt::read_u8(&mut self.0)
157                .await
158                .context("Failed to read u8")
159        }
160        async fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
161            tokio::io::AsyncReadExt::read_exact(&mut self.0, buf)
162                .await
163                .map(|_| ())
164                .context("Failed to read exact")
165        }
166    }
167
168    let mut reader = TokioReader(reader);
169    let v: T = T::decode_async(&mut reader).await?;
170    Ok(v)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
174pub async fn decode_from_hyper_stream<T, S, B>(stream: S) -> crate::Result<T>
175where
176    T: Decode,
177    S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
178    B: hyper::body::Body + Unpin,
179{
180    struct StreamReader<S, B>(S, BytesMut, std::marker::PhantomData<B>);
181
182    impl<S, B> StreamReader<S, B>
183    where
184        S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
185        B: hyper::body::Body + Unpin,
186    {
187        async fn fill_buf(&mut self, until: usize) -> crate::Result<()> {
188            while self.1.remaining() < until {
189                let Some(n) = self.0.next().await else {
190                    return Err(crate::Error::new("UnexpectedEof"));
191                };
192                let mut buf = match n {
193                    Ok(n) => n,
194                    Err(_) => {
195                        return Err(crate::Error::new("UnexpectedEof"));
196                    }
197                };
198                if buf.remaining() == 0 {
199                    return Err(crate::Error::new("UnexpectedEof"));
200                }
201                while buf.remaining() > 0 {
202                    let chunk = buf.chunk();
203                    self.1.put_slice(chunk);
204                    buf.advance(chunk.len());
205                }
206            }
207            Ok(())
208        }
209    }
210
211    impl<S, B> AsyncDecoder for StreamReader<S, B>
212    where
213        S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
214        B: hyper::body::Body + Unpin,
215    {
216        async fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
217            self.fill_buf(buf.len()).await?;
218            self.1.copy_to_slice(buf);
219            Ok(())
220        }
221        async fn read_u8(&mut self) -> crate::Result<u8> {
222            self.fill_buf(1).await?;
223            Ok(self.1.get_u8())
224        }
225    }
226
227    let mut reader: StreamReader<S, B> =
228        StreamReader(stream, BytesMut::new(), std::marker::PhantomData);
229    T::decode_async(&mut reader).await
230}
231
/// Heuristically checks whether `T1` and `T2` are the same type by comparing
/// their [`type_name`]s.
///
/// NOTE(review): `type_name` is documented as not being a unique type
/// identifier, and lifetimes are not part of the name. Distinct types can
/// therefore compare equal. `TypeId` would be exact but requires `'static`
/// bounds. Confirm that callers only rely on this for simple types such as
/// primitives.
fn type_eq<T1, T2>() -> bool {
    let (lhs, rhs) = (type_name::<T1>(), type_name::<T2>());
    lhs == rhs
}
235
/// Reinterprets a slice of `F` as a slice of `T` without copying.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment.
///
/// # Safety
///
/// Every element of `slice` must be a valid bit pattern for `T`.
unsafe fn transmute_slice<F, T>(slice: &[F]) -> &[T] {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // SAFETY: equal size and alignment mean `slice.len()` elements of `T` cover
    // exactly the same, suitably aligned memory as the input, borrowed for the
    // same lifetime. The caller guarantees every element is a valid `T`.
    // Rebuilding the slice from its parts avoids depending on the layout of wide
    // pointers, which the language does not specify.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<T>(), slice.len()) }
}
241
/// Reinterprets a `Vec<F>` as a `Vec<T>`, reusing the same allocation.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment.
///
/// # Safety
///
/// Every element of `vec` must be a valid bit pattern for `T`.
unsafe fn transmute_vec<F, T>(vec: Vec<F>) -> Vec<T> {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // Keep the original `Vec` from freeing the buffer that the new one takes over.
    let mut vec = mem::ManuallyDrop::new(vec);
    let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
    // SAFETY: the buffer was allocated by a `Vec` with `cap * size_of::<F>()`
    // bytes at `align_of::<F>()`. Both values are identical for `T` (asserted
    // above), and `len <= cap`. The caller guarantees the `len` initialised
    // elements are valid `T`s. This is the approach recommended by the
    // `mem::transmute` docs, instead of relying on `Vec`'s unspecified layout.
    unsafe { Vec::from_raw_parts(ptr.cast::<T>(), len, cap) }
}
247
/// Reinterprets an array of `F` as an array of `T`, moving the elements.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment.
///
/// # Safety
///
/// Every element of `arr` must be a valid bit pattern for `T`.
unsafe fn transmute_array<F, T, const N: usize>(arr: [F; N]) -> [T; N] {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // The elements are moved into the returned array by a bitwise read, so the
    // input must not also be dropped here. Otherwise any `F: Drop` would be
    // dropped twice.
    let arr = mem::ManuallyDrop::new(arr);
    // SAFETY: `[F; N]` and `[T; N]` have identical size and alignment (asserted
    // above), and the pointer comes from a live, properly aligned reference. The
    // caller guarantees every element is a valid `T`. `ManuallyDrop` gives the
    // returned array sole ownership of the elements.
    unsafe { (&*arr as *const [F; N]).cast::<[T; N]>().read() }
}
254
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        io::Write,
    };

    use crate::prelude::*;

    use super::*;

    #[test]
    fn test_encoding_byte_vec() {
        let v = vec![1, 2, 3];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<u8> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[test]
    fn test_encoding_vec() {
        let v = vec!["test"];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[test]
    fn test_encoding_byte_array() {
        let a = [1, 2, 3];

        let encoded = encode_to_bytes(&a).unwrap();
        let array: [u8; 3] = decode_from_bytes(encoded).unwrap();

        assert_eq!(array, a);
    }

    #[test]
    fn test_encoding_array() {
        let a = [String::from("a")];

        let encoded = encode_to_bytes(&a).unwrap();
        let array: [String; 1] = decode_from_bytes(encoded).unwrap();

        assert_eq!(array, a);
    }

    #[test]
    fn test_encoding_vec_generic() {
        let v: Vec<i32> = vec![1, 2, 3];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<i32> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestStruct {
        f1: [u8; 32],
        f2: Bytes,
    }

    #[test]
    fn test_encoding_custom_struct() {
        let t = TestStruct {
            f1: [6; 32],
            f2: Bytes::new(),
        };

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestStruct = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestUnnamed([u8; 32], Bytes);

    #[test]
    fn test_encoding_custom_unnamed() {
        let t = TestUnnamed([1; 32], Bytes::new());

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestUnnamed = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    enum TestEnum {
        Unit,
        Unnamed(TestUnnamed),
        Named(TestStruct),
    }

    #[test]
    fn test_encoding_custom_enum() {
        let t = TestEnum::Unnamed(TestUnnamed([0; 32], Bytes::new()));

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestEnum = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestGeneric<D, const N: usize> {
        f1: [u8; N],
        f2: Bytes,
        f3: D,
    }

    #[test]
    fn test_encoding_custom_generic() {
        let t = TestGeneric {
            f1: [3; 32],
            f2: Bytes::new(),
            f3: String::from("test"),
        };

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestGeneric<String, 32> = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestUnsend<D, const N: usize> {
        f1: [u8; N],
        f2: Bytes,
        f3: D,
        _marker: std::marker::PhantomData<std::rc::Rc<()>>, // !Send
    }

    // Runs directly in the test's async body (no `tokio::spawn`), so the
    // future holding the `!Send` value never needs to be `Send`.
    #[tokio::test]
    async fn test_encoding_tokio_unsend() {
        let t = TestUnsend {
            f1: [3; 32],
            f2: Bytes::new(),
            f3: String::from("test"),
            _marker: std::marker::PhantomData,
        };

        let mut writer = std::io::Cursor::new(Vec::new());
        encode_to_tokio(&t, &mut writer).await.unwrap();
        writer.flush().unwrap();
        let encoded = writer.into_inner();

        let mut reader = std::io::Cursor::new(encoded);
        let test: TestUnsend<String, 32> = decode_from_tokio(&mut reader).await.unwrap();

        assert_eq!(test, t);
    }

    // Runs inside `tokio::spawn`, which only compiles if the encode/decode
    // futures are `Send`.
    #[tokio::test]
    async fn test_encoding_tokio_send() {
        let handle = tokio::spawn(async {
            let t = TestGeneric {
                f1: [3; 32],
                f2: Bytes::new(),
                f3: String::from("test"),
            };

            let mut writer = std::io::Cursor::new(Vec::new());
            encode_to_tokio(&t, &mut writer).await.unwrap();
            writer.flush().unwrap();
            let encoded = writer.into_inner();

            let mut reader = std::io::Cursor::new(encoded);
            let test: TestGeneric<String, 32> = decode_from_tokio(&mut reader).await.unwrap();

            assert_eq!(test, t);
        });

        handle.await.unwrap();
    }

    #[test]
    fn test_encoding_hash_map() {
        let m = HashMap::from([(1u8, String::from("test")), (2u8, String::from("test1"))]);

        let encoded = encode_to_bytes(&m).unwrap();
        let map: HashMap<u8, String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(map, m);
    }

    #[test]
    fn test_encoding_hash_set() {
        let s = HashSet::from([1u8, 2u8]);

        let encoded = encode_to_bytes(&s).unwrap();
        let set: HashSet<u8> = decode_from_bytes(encoded).unwrap();

        assert_eq!(set, s);

        let s = HashSet::from([String::from("test"), String::from("test1")]);

        let encoded = encode_to_bytes(&s).unwrap();
        let set: HashSet<String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(set, s);
    }

    // Only exercised at compile time: the derives must accept an enum with no variants.
    #[derive(Encode, Decode)]
    enum NoVariant {}

    #[derive(Debug, PartialEq, Eq, Encode, Decode)]
    enum Skipped {
        First(String),

        #[encode(skip)]
        Second(String),

        Third(String, #[encode(skip)] String, String),

        Forth {
            first: String,
            #[encode(skip)]
            second: String,
            third: String,
        },
    }

    #[test]
    fn test_encoding_skip() {
        let s1 = Skipped::First(String::from("first"));
        let s2 = Skipped::Second(String::from("second"));
        let s3 = Skipped::Third(
            String::from("third1"),
            String::from("third2"),
            String::from("third3"),
        );
        let s4 = Skipped::Forth {
            first: String::from("forth1"),
            second: String::from("forth2"),
            third: String::from("forth3"),
        };

        let en1 = encode_to_bytes(&s1).unwrap();
        let en2 = encode_to_bytes(&s2).unwrap();
        let en3 = encode_to_bytes(&s3).unwrap();
        let en4 = encode_to_bytes(&s4).unwrap();

        let skip1: Skipped = decode_from_bytes(en1).unwrap();
        let skip2 = decode_from_bytes::<Skipped>(en2).unwrap_err();
        let skip3: Skipped = decode_from_bytes(en3).unwrap();
        let skip4: Skipped = decode_from_bytes(en4).unwrap();

        assert_eq!(skip1, s1);
        assert_eq!(skip2.context(), "Variant is skipped and cannot be decoded");
        // Skipped fields come back as empty strings rather than their encoded values.
        assert_eq!(
            skip3,
            Skipped::Third(
                String::from("third1"),
                String::from(""),
                String::from("third3"),
            )
        );
        assert_eq!(
            skip4,
            Skipped::Forth {
                first: String::from("forth1"),
                second: String::from(""),
                third: String::from("forth3"),
            }
        );
    }
}