tower_duplex/
serialize.rs

1use std::marker::PhantomData;
2
3use serde::ser::SerializeTuple;
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5
6use super::DuplexValue;
7
8const RESPONSE_BIT: u8 = 1 << 7;
9
10impl<Request, Response> Serialize for DuplexValue<Request, Response>
11where
12    Request: Serialize,
13    Response: Serialize,
14{
15    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
16    where
17        S: Serializer,
18    {
19        let mut tup = serializer.serialize_tuple(2)?;
20        match self {
21            DuplexValue::Request(tag, request) => {
22                tup.serialize_element(&tag)?;
23                tup.serialize_element(request)?;
24            }
25            DuplexValue::Response(tag, response) => {
26                // We differentiate between request and response based on the MSB
27                tup.serialize_element(&(*tag | RESPONSE_BIT))?;
28                tup.serialize_element(response)?;
29            }
30        };
31        tup.end()
32    }
33}
34
35impl<'de, Request, Response> Deserialize<'de> for DuplexValue<Request, Response>
36where
37    for<'d> Request: Deserialize<'d>,
38    for<'d> Response: Deserialize<'d>,
39{
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: Deserializer<'de>,
43    {
44        struct DuplexValueVisitor<Request, Response> {
45            _req: PhantomData<Request>,
46            _res: PhantomData<Response>,
47        }
48
49        impl<'de, Request, Response> de::Visitor<'de> for DuplexValueVisitor<Request, Response>
50        where
51            Request: Deserialize<'de>,
52            Response: Deserialize<'de>,
53        {
54            type Value = DuplexValue<Request, Response>;
55
56            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
57                formatter.write_str("a tuple of u8 and request/response")
58            }
59
60            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
61            where
62                A: de::SeqAccess<'de>,
63                A::Error: de::Error,
64            {
65                let tag: u8 = match seq.next_element()? {
66                    Some(tag) => tag,
67                    None => return Err(de::Error::invalid_length(0, &self)),
68                };
69
70                // The top bit of the tag tells us if that is a response or a request
71                if tag & RESPONSE_BIT != 0 {
72                    let response: Response = match seq.next_element()? {
73                        Some(resp) => resp,
74                        None => return Err(de::Error::invalid_length(1, &self)),
75                    };
76                    Ok(DuplexValue::Response(tag & !RESPONSE_BIT, response))
77                } else {
78                    let request: Request = match seq.next_element()? {
79                        Some(req) => req,
80                        None => return Err(de::Error::invalid_length(1, &self)),
81                    };
82                    Ok(DuplexValue::Request(tag, request))
83                }
84            }
85        }
86
87        deserializer.deserialize_tuple(
88            2,
89            DuplexValueVisitor {
90                _req: Default::default(),
91                _res: Default::default(),
92            },
93        )
94    }
95}