skyway_webrtc_gateway_api/common/formats.rs

1use std::fmt;
2use std::net::{IpAddr, SocketAddr};
3
4use serde::de::{self, Deserializer, MapAccess, Visitor};
5use serde::ser::{SerializeStruct, Serializer};
6use serde::{Deserialize, Serialize};
7
8use crate::error;
9
10/// This trait is for serializing ID to JSON.
11///
12/// It also has some getter functions.
13pub trait SerializableId: Clone {
14    /// Try to create an instance of SerializableId with String parameter.
15    ///
16    /// It returns None, if id is None.
17    fn try_create(id: impl Into<String>) -> Result<Self, error::Error>
18    where
19        Self: Sized;
20    /// Get internal str of the Id.
21    fn as_str(&self) -> &str;
22    /// Get internal String of the Id
23    fn id(&self) -> String;
24    /// Field name of Json. If it returns `"hoge_id"`, json will be `{"hoge_id": self.id()}`.
25    fn key(&self) -> &'static str;
26}
27
28/// This trait is for serializing SocketInfo to JSON.
29///
30/// It also has some getter functions.
31pub trait SerializableSocket<T> {
32    /// Create an instance of SerializableSocket.
33    ///
34    /// # Failures
35    /// It returns error, if the ip and port is not valid for SocketAddr.
36    fn try_create(id: Option<String>, ip: &str, port: u16) -> Result<Self, error::Error>
37    where
38        Self: Sized;
39    /// Returns id field.
40    fn get_id(&self) -> Option<T>;
41    /// Field name of Json.
42    fn key(&self) -> &'static str;
43    /// Returns SocketAddr of the socket.
44    fn addr(&self) -> &SocketAddr;
45    /// Returns IpAddr of the socket.
46    fn ip(&self) -> IpAddr;
47    /// Returns port number of the socket.
48    fn port(&self) -> u16;
49}
50
51/// There are several field which has some kind of id and SocketAddr.
52///
53/// This struct covers all of them.
54#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord, Hash)]
55pub struct SocketInfo<T: SerializableId> {
56    id: Option<T>,
57    socket: SocketAddr,
58}
59
60impl<T: SerializableId> SerializableSocket<T> for SocketInfo<T> {
61    fn try_create(id: Option<String>, ip: &str, port: u16) -> Result<Self, error::Error> {
62        let ip: IpAddr = ip.parse()?;
63        let socket = SocketAddr::new(ip, port);
64        match id {
65            Some(id) => Ok(Self {
66                id: Some(T::try_create(id)?),
67                socket: socket,
68            }),
69            None => Ok(Self {
70                id: None,
71                socket: socket,
72            }),
73        }
74    }
75
76    fn get_id(&self) -> Option<T> {
77        self.id.clone()
78    }
79
80    fn key(&self) -> &'static str {
81        match self.id {
82            Some(ref id) => id.key(),
83            None => "",
84        }
85    }
86
87    fn addr(&self) -> &SocketAddr {
88        &self.socket
89    }
90
91    fn ip(&self) -> IpAddr {
92        self.socket.ip()
93    }
94
95    fn port(&self) -> u16 {
96        self.socket.port()
97    }
98}
99
100impl<T: SerializableId> Serialize for SocketInfo<T> {
101    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
102    where
103        S: Serializer,
104    {
105        let key = self.key();
106        let id = self.get_id();
107        let mut serial;
108        if key.len() == 0 {
109            serial = serializer.serialize_struct("SocketAddr", 2)?
110        } else {
111            serial = serializer.serialize_struct("SocketAddr", 3)?;
112            serial.serialize_field(key, &(id.expect("no id")).id())?;
113        };
114
115        let ip = self.ip();
116        if ip.is_ipv4() {
117            serial.serialize_field("ip_v4", &ip.to_string())?;
118        } else {
119            serial.serialize_field("ip_v6", &ip.to_string())?;
120        }
121        serial.serialize_field("port", &self.port())?;
122        serial.end()
123    }
124}
125
126impl<'de, X: SerializableId> Deserialize<'de> for SocketInfo<X> {
127    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128    where
129        D: Deserializer<'de>,
130    {
131        use std::marker::PhantomData;
132        enum Field {
133            IP,
134            PORT,
135            ID,
136        }
137
138        impl<'de> Deserialize<'de> for Field {
139            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
140            where
141                D: Deserializer<'de>,
142            {
143                struct FieldVisitor;
144
145                impl<'de> Visitor<'de> for FieldVisitor {
146                    type Value = Field;
147
148                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
149                        formatter.write_str("`ip_v4` or `ip_v6` or `port` or `*_id`")
150                    }
151
152                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
153                    where
154                        E: de::Error,
155                    {
156                        match value {
157                            "ip_v4" => Ok(Field::IP),
158                            "ip_v6" => Ok(Field::IP),
159                            "port" => Ok(Field::PORT),
160                            id if id.ends_with("_id") => Ok(Field::ID),
161                            _ => Err(de::Error::unknown_field(value, FIELDS)),
162                        }
163                    }
164                }
165
166                deserializer.deserialize_identifier(FieldVisitor)
167            }
168        }
169
170        struct SocketInfoVisitor<T>(PhantomData<T>);
171
172        impl<'de, T: SerializableId> Visitor<'de> for SocketInfoVisitor<T> {
173            type Value = SocketInfo<T>;
174
175            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176                formatter.write_str("struct SocketAddr")
177            }
178
179            fn visit_map<V>(self, mut map: V) -> Result<SocketInfo<T>, V::Error>
180            where
181                V: MapAccess<'de>,
182            {
183                let mut ip: Option<String> = None;
184                let mut id: Option<String> = None;
185                let mut port: Option<u16> = None;
186                while let Some(key) = map.next_key()? {
187                    match key {
188                        Field::PORT => {
189                            if port.is_some() {
190                                return Err(de::Error::duplicate_field("port"));
191                            }
192                            port = Some(map.next_value()?);
193                        }
194                        Field::IP => {
195                            if ip.is_some() {
196                                return Err(de::Error::duplicate_field("ip_v4"));
197                            }
198                            ip = Some(map.next_value()?);
199                        }
200                        Field::ID => {
201                            if id.is_some() {
202                                return Err(de::Error::duplicate_field("id"));
203                            }
204                            id = Some(map.next_value()?);
205                        }
206                    }
207                }
208                let ip = ip.ok_or_else(|| de::Error::missing_field("ip_v4 or ip_v6"))?;
209                let port = port.ok_or_else(|| de::Error::missing_field("port"))?;
210                let socket_info = SocketInfo::<T>::try_create(id, &ip, port);
211                if let Err(_err) = socket_info {
212                    use serde::de::Error;
213                    return Err(V::Error::custom(format!("fail to deserialize socket")));
214                }
215
216                Ok(socket_info.unwrap())
217            }
218        }
219
220        const FIELDS: &'static [&'static str] = &["ip_v4", "ip_v6", "port", "*_id"];
221        deserializer.deserialize_struct("SocketAddr", FIELDS, SocketInfoVisitor(PhantomData))
222    }
223}
224
225/// It's just a dummy Id data returning None.
226///
227/// There are many similar structs holding SocketAddr and a kind of ID.
228/// PhantomId is for a struct which doesn't have id field.
229/// It will be set as a generics parameter of `SocketInfo`.
230#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
231pub struct PhantomId(String);
232
233impl SerializableId for PhantomId {
234    fn try_create(id: impl Into<String>) -> Result<Self, error::Error>
235    where
236        Self: Sized,
237    {
238        Ok(PhantomId(id.into()))
239    }
240
241    fn as_str(&self) -> &str {
242        ""
243    }
244
245    fn id(&self) -> String {
246        String::from("")
247    }
248
249    fn key(&self) -> &'static str {
250        ""
251    }
252}
253
#[cfg(test)]
mod test_socket_info {
    use super::*;

    #[test]
    fn v4() {
        let socket_info = SocketInfo::<PhantomId>::try_create(None, "127.0.0.1", 8000).unwrap();
        let json = serde_json::to_string(&socket_info).expect("serialize failed");
        // Pin the wire format: no id field for PhantomId, ip under `ip_v4`.
        assert_eq!(json, r#"{"ip_v4":"127.0.0.1","port":8000}"#);
        let decoded_socket_info: SocketInfo<PhantomId> =
            serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(socket_info, decoded_socket_info);
    }

    #[test]
    fn v6() {
        let socket_info =
            SocketInfo::<PhantomId>::try_create(None, "2001:DB8:0:0:8:800:200C:417A", 8000)
                .unwrap();
        let json = serde_json::to_string(&socket_info).expect("serialize failed");
        // IPv6 addresses are written in Rust's canonical (compressed, lowercase) form.
        assert_eq!(json, r#"{"ip_v6":"2001:db8::8:800:200c:417a","port":8000}"#);
        let decoded_socket_info: SocketInfo<PhantomId> =
            serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(socket_info, decoded_socket_info);
    }

    #[test]
    fn invalid_ip_is_rejected() {
        assert!(SocketInfo::<PhantomId>::try_create(None, "not an ip", 8000).is_err());
    }

    #[test]
    fn missing_port_is_rejected() {
        let result: Result<SocketInfo<PhantomId>, _> =
            serde_json::from_str(r#"{"ip_v4":"127.0.0.1"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn unknown_field_is_rejected() {
        let result: Result<SocketInfo<PhantomId>, _> =
            serde_json::from_str(r#"{"ip_v4":"127.0.0.1","port":8000,"foo":1}"#);
        assert!(result.is_err());
    }

    #[test]
    fn ip_given_twice_is_rejected() {
        let result: Result<SocketInfo<PhantomId>, _> =
            serde_json::from_str(r#"{"ip_v4":"127.0.0.1","ip_v6":"::1","port":8000}"#);
        assert!(result.is_err());
    }
}