serde_request_envelope/
request.rs

1use std::fmt;
2
3use serde::{
4    de::{self, Deserialize, Deserializer, MapAccess, Visitor},
5    ser::SerializeStruct,
6    Serialize,
7};
8
9use crate::support;
10
/// A request envelope that tags its payload with a type name.
///
/// Serializing a `Request<T>` emits a two-field struct named `Request`:
/// `type` holds the name reported by `support::type_name` for the wrapped
/// value, and `data` holds the value itself. Deserializing checks that the
/// incoming `type` matches the name derived from the decoded `data`.
///
/// Wrap your structs in this to get type-tagged structures with data contents.
/// The comparison/hash/default derives only apply when `T` supports them.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Request<T>(pub T);
15
16impl<T> Serialize for Request<T>
17where
18    T: Serialize,
19{
20    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
21    where
22        S: serde::Serializer,
23    {
24        let mut serializer = serializer.serialize_struct("Request", 2)?;
25        serializer.serialize_field(
26            "type",
27            &support::type_name(&self.0).map_err(|_| serde::ser::Error::custom("not struct"))?,
28        )?;
29        serializer.serialize_field("data", &self.0)?;
30        serializer.end()
31    }
32}
33
34impl<T> Request<T> {
35    /// Create a new request envelope with the given data.
36    pub fn new(data: T) -> Self {
37        Self(data)
38    }
39}
40
41impl<'de, Data> Deserialize<'de> for Request<Data>
42where
43    Data: Deserialize<'de> + Serialize,
44{
45    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46    where
47        D: Deserializer<'de>,
48    {
49        struct RequestVisitor<Data> {
50            marker: std::marker::PhantomData<Data>,
51        }
52
53        impl<'de, Data> Visitor<'de> for RequestVisitor<Data>
54        where
55            Data: Deserialize<'de> + Serialize,
56        {
57            type Value = Request<Data>;
58
59            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
60                formatter.write_str("struct Request")
61            }
62
63            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
64            where
65                A: MapAccess<'de>,
66            {
67                let mut given_type: Option<String> = None;
68                let mut data = None;
69
70                while let Some(key) = map.next_key()? {
71                    match key {
72                        "type" => {
73                            given_type = Some(map.next_value()?);
74                        }
75                        "data" => {
76                            data = Some(map.next_value()?);
77                        }
78                        _ => {
79                            let _: serde::de::IgnoredAny = map.next_value()?;
80                        }
81                    }
82                }
83
84                let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
85                let given_type = given_type.ok_or_else(|| de::Error::missing_field("type"))?;
86
87                let expected_type =
88                    support::type_name(&data).map_err(|_| de::Error::custom("not struct"))?;
89
90                if expected_type != given_type {
91                    return Err(de::Error::custom(format!(
92                        "wrong type: expected {expected_type}, got {given_type}"
93                    )));
94                }
95                Ok(Request::new(data))
96            }
97        }
98
99        deserializer.deserialize_struct(
100            "Request",
101            &["type", "data"],
102            RequestVisitor {
103                marker: std::marker::PhantomData,
104            },
105        )
106    }
107}