1mod decode;
12mod encode;
13
14use std::{any::type_name, mem};
15
16use crate::{
17 bytes::{Bytes, BytesMut},
18 prelude::*,
19};
20
21use bytes::{Buf, BufMut};
22pub use decode::{AsyncDecoder, Decode, Decoder};
23pub use encode::{AsyncEncoder, Encode, Encoder};
24pub use rustolio_utils_macros::{Decode, Encode, async_impl};
25
26pub fn encode_to_bytes(v: &impl Encode) -> crate::Result<Bytes> {
27 struct BytesWriter(BytesMut);
28
29 impl Encoder for BytesWriter {
30 fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
31 self.0.put_u8(byte);
32 Ok(())
33 }
34 fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
35 self.0.extend_from_slice(bytes);
36 Ok(())
37 }
38 }
39
40 let size = v.encode_size();
41 let mut writer = BytesWriter(BytesMut::with_capacity(size));
42 v.encode(&mut writer)?;
43 let bytes = writer.0.freeze();
44
45 #[cfg(test)]
46 assert_eq!(bytes.len(), size);
47
48 Ok(bytes)
49}
50
51pub fn decode_from_bytes<T: Decode>(bytes: Bytes) -> crate::Result<T> {
52 struct BytesReader(Bytes);
53
54 impl Decoder for BytesReader {
55 fn read_u8(&mut self) -> crate::Result<u8> {
56 Ok(self.0.get_u8())
57 }
58 fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
59 self.0.try_copy_to_slice(buf).context("UnexpectedEof")
60 }
61 }
62
63 let mut reader = BytesReader(bytes);
64 let v: T = T::decode(&mut reader)?;
65 Ok(v)
66}
67
68pub fn encode_to_std<W>(v: &impl Encode, writer: &mut W) -> crate::Result<()>
69where
70 W: std::io::Write + Unpin,
71{
72 struct StdWriter<W>(W);
73
74 impl<W> Encoder for StdWriter<W>
75 where
76 W: std::io::Write + Unpin,
77 {
78 fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
79 std::io::Write::write_all(&mut self.0, &[byte]).context("Failed to write u8")
80 }
81 fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
82 std::io::Write::write_all(&mut self.0, bytes).context("Failed to write all")
83 }
84 }
85
86 let mut writer = StdWriter(writer);
87 v.encode(&mut writer)?;
88 Ok(())
89}
90
91pub fn decode_from_std<R, T: Decode>(reader: &mut R) -> crate::Result<T>
92where
93 R: std::io::Read + Unpin,
94{
95 struct StdReader<R>(R);
96
97 impl<R> Decoder for StdReader<R>
98 where
99 R: std::io::Read + Unpin,
100 {
101 fn read_u8(&mut self) -> crate::Result<u8> {
102 let mut buf = [0];
103 std::io::Read::read_exact(&mut self.0, &mut buf).context("Failed to read u8")?;
104 Ok(buf[0])
105 }
106 fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
107 std::io::Read::read_exact(&mut self.0, buf).context("Failed to read exact")
108 }
109 }
110
111 let mut reader = StdReader(reader);
112 let v: T = T::decode(&mut reader)?;
113 Ok(v)
114}
115
116#[cfg(not(target_arch = "wasm32"))]
117pub async fn encode_to_tokio<W>(v: &impl Encode, writer: &mut W) -> crate::Result<()>
118where
119 W: tokio::io::AsyncWriteExt + Unpin,
120{
121 struct TokioWriter<W>(W);
122
123 impl<W> AsyncEncoder for TokioWriter<W>
124 where
125 W: tokio::io::AsyncWriteExt + Unpin,
126 {
127 async fn write_u8(&mut self, byte: u8) -> crate::Result<()> {
128 tokio::io::AsyncWriteExt::write_u8(&mut self.0, byte)
129 .await
130 .context("Failed to write u8")
131 }
132 async fn write_all(&mut self, bytes: &[u8]) -> crate::Result<()> {
133 tokio::io::AsyncWriteExt::write_all(&mut self.0, bytes)
134 .await
135 .context("Failed to write all")
136 }
137 }
138
139 let mut writer = TokioWriter(writer);
140 v.encode_async(&mut writer).await?;
141 Ok(())
142}
143
144#[cfg(not(target_arch = "wasm32"))]
145pub async fn decode_from_tokio<R, T: Decode>(reader: &mut R) -> crate::Result<T>
146where
147 R: tokio::io::AsyncReadExt + Unpin,
148{
149 struct TokioReader<R>(R);
150
151 impl<R> AsyncDecoder for TokioReader<R>
152 where
153 R: tokio::io::AsyncReadExt + Unpin,
154 {
155 async fn read_u8(&mut self) -> crate::Result<u8> {
156 tokio::io::AsyncReadExt::read_u8(&mut self.0)
157 .await
158 .context("Failed to read u8")
159 }
160 async fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
161 tokio::io::AsyncReadExt::read_exact(&mut self.0, buf)
162 .await
163 .map(|_| ())
164 .context("Failed to read exact")
165 }
166 }
167
168 let mut reader = TokioReader(reader);
169 let v: T = T::decode_async(&mut reader).await?;
170 Ok(v)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
174pub async fn decode_from_hyper_stream<T, S, B>(stream: S) -> crate::Result<T>
175where
176 T: Decode,
177 S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
178 B: hyper::body::Body + Unpin,
179{
180 struct StreamReader<S, B>(S, BytesMut, std::marker::PhantomData<B>);
181
182 impl<S, B> StreamReader<S, B>
183 where
184 S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
185 B: hyper::body::Body + Unpin,
186 {
187 async fn fill_buf(&mut self, until: usize) -> crate::Result<()> {
188 while self.1.remaining() < until {
189 let Some(n) = self.0.next().await else {
190 return Err(crate::Error::new("UnexpectedEof"));
191 };
192 let mut buf = match n {
193 Ok(n) => n,
194 Err(_) => {
195 return Err(crate::Error::new("UnexpectedEof"));
196 }
197 };
198 if buf.remaining() == 0 {
199 return Err(crate::Error::new("UnexpectedEof"));
200 }
201 while buf.remaining() > 0 {
202 let chunk = buf.chunk();
203 self.1.put_slice(chunk);
204 buf.advance(chunk.len());
205 }
206 }
207 Ok(())
208 }
209 }
210
211 impl<S, B> AsyncDecoder for StreamReader<S, B>
212 where
213 S: futures::StreamExt<Item = std::result::Result<B::Data, B::Error>> + Unpin,
214 B: hyper::body::Body + Unpin,
215 {
216 async fn read_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
217 self.fill_buf(buf.len()).await?;
218 self.1.copy_to_slice(buf);
219 Ok(())
220 }
221 async fn read_u8(&mut self) -> crate::Result<u8> {
222 self.fill_buf(1).await?;
223 Ok(self.1.get_u8())
224 }
225 }
226
227 let mut reader: StreamReader<S, B> =
228 StreamReader(stream, BytesMut::new(), std::marker::PhantomData);
229 T::decode_async(&mut reader).await
230}
231
/// Reports whether `T1` and `T2` have identical [`type_name`]s.
///
/// NOTE(review): `type_name` is documented as a diagnostic aid whose output is not a unique
/// identifier. For example, lifetimes are not included, so distinct types can compare equal
/// here. `TypeId` would be exact but requires `'static` types. Confirm this is acceptable if
/// it is used to gate the `transmute_*` helpers below.
fn type_eq<T1, T2>() -> bool {
    let (lhs, rhs) = (type_name::<T1>(), type_name::<T2>());
    lhs == rhs
}
235
/// Reinterprets a slice of `F` as a slice of `T` with the same length, without copying.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment.
///
/// # Safety
///
/// Every element of `slice` must be a valid value of type `T`.
unsafe fn transmute_slice<F, T>(slice: &[F]) -> &[T] {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // Build the new slice from its pointer and length (the documented construction) instead
    // of transmuting the fat `&[F]` reference itself.
    // SAFETY: the pointer comes from a live `&[F]`, so it is non-null and aligned for `T`
    // (alignments asserted equal). It covers `len * size_of::<T>()` initialized bytes (sizes
    // asserted equal). The result borrows for `slice`'s lifetime. Element validity as `T` is
    // the caller's obligation.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<T>(), slice.len()) }
}
241
/// Reinterprets a `Vec<F>` as a `Vec<T>`, reusing its heap allocation.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment. On a panic, `vec` is dropped normally.
///
/// # Safety
///
/// Every element of `vec` must be a valid value of type `T`.
unsafe fn transmute_vec<F, T>(vec: Vec<F>) -> Vec<T> {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // `Vec`'s field layout is unspecified, so transmuting the `Vec` value itself is not
    // guaranteed to be sound. Take it apart and rebuild it from raw parts instead.
    // `ManuallyDrop` stops the original from freeing the buffer handed to the new `Vec`.
    let mut vec = mem::ManuallyDrop::new(vec);
    let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
    // SAFETY: `ptr`/`cap` describe a buffer allocated by `Vec<F>` with the global allocator.
    // `T` has the same size and alignment as `F`, so the allocation layout is identical, and
    // the first `len` elements are initialized. Element validity as `T` is the caller's
    // obligation.
    unsafe { Vec::from_raw_parts(ptr.cast::<T>(), len, cap) }
}
247
/// Reinterprets an array `[F; N]` as `[T; N]`, moving ownership of every element.
///
/// # Panics
///
/// Panics if `F` and `T` differ in size or alignment. On a panic, `arr` is dropped normally.
///
/// # Safety
///
/// Every element of `arr` must be a valid value of type `T`.
unsafe fn transmute_array<F, T, const N: usize>(arr: [F; N]) -> [T; N] {
    assert_eq!(mem::size_of::<F>(), mem::size_of::<T>());
    assert_eq!(mem::align_of::<F>(), mem::align_of::<T>());
    // The elements are moved into the returned array by a bitwise read. Without
    // `ManuallyDrop`, `arr` would also be dropped at the end of this function, dropping
    // every element twice (a double free for types such as `String`).
    let arr = mem::ManuallyDrop::new(arr);
    let src: *const [F; N] = &*arr;
    // SAFETY: `src` points to a live, initialized `[F; N]`, and `[T; N]` has the same size
    // and alignment (asserted above). The source is never dropped, so ownership moves
    // exactly once. Element validity as `T` is the caller's obligation.
    unsafe { src.cast::<[T; N]>().read() }
}
254
#[cfg(test)]
mod tests {
    // Round-trip tests: each value is encoded, decoded back and compared with the original.

    use std::{
        collections::{HashMap, HashSet},
        io::Write,
    };

    use crate::prelude::*;

    use super::*;

    #[test]
    fn test_encoding_byte_vec() {
        let v = vec![1, 2, 3];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<u8> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[test]
    fn test_encoding_vec() {
        let v = vec!["test"];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[test]
    fn test_encoding_byte_array() {
        let a = [1, 2, 3];

        let encoded = encode_to_bytes(&a).unwrap();
        let array: [u8; 3] = decode_from_bytes(encoded).unwrap();

        assert_eq!(array, a);
    }

    #[test]
    fn test_encoding_array() {
        let a = [String::from("a")];

        let encoded = encode_to_bytes(&a).unwrap();
        let array: [String; 1] = decode_from_bytes(encoded).unwrap();

        assert_eq!(array, a);
    }

    #[test]
    fn test_encoding_vec_generic() {
        let v: Vec<i32> = vec![1, 2, 3];

        let encoded = encode_to_bytes(&v).unwrap();
        let vec: Vec<i32> = decode_from_bytes(encoded).unwrap();

        assert_eq!(vec, v);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestStruct {
        f1: [u8; 32],
        f2: Bytes,
    }

    #[test]
    fn test_encoding_custom_struct() {
        let t = TestStruct {
            f1: [6; 32],
            f2: Bytes::new(),
        };

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestStruct = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestUnnamed([u8; 32], Bytes);

    #[test]
    fn test_encoding_custom_unnamed() {
        let t = TestUnnamed([1; 32], Bytes::new());

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestUnnamed = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    enum TestEnum {
        Unit,
        Unnamed(TestUnnamed),
        Named(TestStruct),
    }

    #[test]
    fn test_encoding_custom_enum() {
        let t = TestEnum::Unnamed(TestUnnamed([0; 32], Bytes::new()));

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestEnum = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestGeneric<D, const N: usize> {
        f1: [u8; N],
        f2: Bytes,
        f3: D,
    }

    #[test]
    fn test_encoding_custom_generic() {
        let t = TestGeneric {
            f1: [3; 32],
            f2: Bytes::new(),
            f3: String::from("test"),
        };

        let encoded = encode_to_bytes(&t).unwrap();
        let test: TestGeneric<String, 32> = decode_from_bytes(encoded).unwrap();

        assert_eq!(test, t);
    }

    #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
    struct TestUnsend<D, const N: usize> {
        f1: [u8; N],
        f2: Bytes,
        f3: D,
        // `Rc` is `!Send`, so this marker makes the whole struct `!Send`.
        _marker: std::marker::PhantomData<std::rc::Rc<()>>,
    }

    // Round-trips a `!Send` value through the tokio encode/decode helpers.
    #[tokio::test]
    async fn test_encoding_tokio_unsend() {
        let t = TestUnsend {
            f1: [3; 32],
            f2: Bytes::new(),
            f3: String::from("test"),
            _marker: std::marker::PhantomData,
        };

        let mut writer = std::io::Cursor::new(Vec::new());
        encode_to_tokio(&t, &mut writer).await.unwrap();
        writer.flush().unwrap();
        let encoded = writer.into_inner();

        let mut reader = std::io::Cursor::new(encoded);
        let test: TestUnsend<String, 32> = decode_from_tokio(&mut reader).await.unwrap();

        assert_eq!(test, t);
    }

    // Runs the tokio round trip inside `tokio::spawn`, which requires the future to be `Send`.
    #[tokio::test]
    async fn test_encoding_tokio_send() {
        let handle = tokio::spawn(async {
            let t = TestGeneric {
                f1: [3; 32],
                f2: Bytes::new(),
                f3: String::from("test"),
            };

            let mut writer = std::io::Cursor::new(Vec::new());
            encode_to_tokio(&t, &mut writer).await.unwrap();
            writer.flush().unwrap();
            let encoded = writer.into_inner();

            let mut reader = std::io::Cursor::new(encoded);
            let test: TestGeneric<String, 32> = decode_from_tokio(&mut reader).await.unwrap();

            assert_eq!(test, t);
        });

        handle.await.unwrap();
    }

    // Synchronous round trip: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_encoding_hash_map() {
        let m = HashMap::from([(1u8, String::from("test")), (2u8, String::from("test1"))]);

        let encoded = encode_to_bytes(&m).unwrap();
        let map: HashMap<u8, String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(map, m);
    }

    // Synchronous round trip: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_encoding_hash_set() {
        let s = HashSet::from([1u8, 2u8]);

        let encoded = encode_to_bytes(&s).unwrap();
        let set: HashSet<u8> = decode_from_bytes(encoded).unwrap();

        assert_eq!(set, s);

        let s = HashSet::from([String::from("test"), String::from("test1")]);

        let encoded = encode_to_bytes(&s).unwrap();
        let set: HashSet<String> = decode_from_bytes(encoded).unwrap();

        assert_eq!(set, s);
    }

    // Compile-only check: the derives must accept an enum with no variants.
    #[derive(Encode, Decode)]
    enum NoVariant {}

    #[derive(Debug, PartialEq, Eq, Encode, Decode)]
    enum Skipped {
        First(String),

        #[encode(skip)]
        Second(String),

        Third(String, #[encode(skip)] String, String),

        Forth {
            first: String,
            #[encode(skip)]
            second: String,
            third: String,
        },
    }

    // Skipped variants fail to decode. Skipped fields decode as their default (empty `String`).
    #[test]
    fn test_encoding_skip() {
        let s1 = Skipped::First(String::from("first"));
        let s2 = Skipped::Second(String::from("second"));
        let s3 = Skipped::Third(
            String::from("third1"),
            String::from("third2"),
            String::from("third3"),
        );
        let s4 = Skipped::Forth {
            first: String::from("forth1"),
            second: String::from("forth2"),
            third: String::from("forth3"),
        };

        let en1 = encode_to_bytes(&s1).unwrap();
        let en2 = encode_to_bytes(&s2).unwrap();
        let en3 = encode_to_bytes(&s3).unwrap();
        let en4 = encode_to_bytes(&s4).unwrap();

        let skip1: Skipped = decode_from_bytes(en1).unwrap();
        let skip2 = decode_from_bytes::<Skipped>(en2).unwrap_err();
        let skip3: Skipped = decode_from_bytes(en3).unwrap();
        let skip4: Skipped = decode_from_bytes(en4).unwrap();

        assert_eq!(skip1, s1);
        assert_eq!(skip2.context(), "Variant is skipped and cannot be decoded");
        assert_eq!(
            skip3,
            Skipped::Third(
                String::from("third1"),
                String::from(""),
                String::from("third3"),
            )
        );
        assert_eq!(
            skip4,
            Skipped::Forth {
                first: String::from("forth1"),
                second: String::from(""),
                third: String::from("forth3"),
            }
        );
    }
}