skyway_webrtc_gateway_api/common/
formats.rs1use std::fmt;
2use std::net::{IpAddr, SocketAddr};
3
4use serde::de::{self, Deserializer, MapAccess, Visitor};
5use serde::ser::{SerializeStruct, Serializer};
6use serde::{Deserialize, Serialize};
7
8use crate::error;
9
10pub trait SerializableId: Clone {
14 fn try_create(id: impl Into<String>) -> Result<Self, error::Error>
18 where
19 Self: Sized;
20 fn as_str(&self) -> &str;
22 fn id(&self) -> String;
24 fn key(&self) -> &'static str;
26}
27
28pub trait SerializableSocket<T> {
32 fn try_create(id: Option<String>, ip: &str, port: u16) -> Result<Self, error::Error>
37 where
38 Self: Sized;
39 fn get_id(&self) -> Option<T>;
41 fn key(&self) -> &'static str;
43 fn addr(&self) -> &SocketAddr;
45 fn ip(&self) -> IpAddr;
47 fn port(&self) -> u16;
49}
50
51#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord, Hash)]
55pub struct SocketInfo<T: SerializableId> {
56 id: Option<T>,
57 socket: SocketAddr,
58}
59
60impl<T: SerializableId> SerializableSocket<T> for SocketInfo<T> {
61 fn try_create(id: Option<String>, ip: &str, port: u16) -> Result<Self, error::Error> {
62 let ip: IpAddr = ip.parse()?;
63 let socket = SocketAddr::new(ip, port);
64 match id {
65 Some(id) => Ok(Self {
66 id: Some(T::try_create(id)?),
67 socket: socket,
68 }),
69 None => Ok(Self {
70 id: None,
71 socket: socket,
72 }),
73 }
74 }
75
76 fn get_id(&self) -> Option<T> {
77 self.id.clone()
78 }
79
80 fn key(&self) -> &'static str {
81 match self.id {
82 Some(ref id) => id.key(),
83 None => "",
84 }
85 }
86
87 fn addr(&self) -> &SocketAddr {
88 &self.socket
89 }
90
91 fn ip(&self) -> IpAddr {
92 self.socket.ip()
93 }
94
95 fn port(&self) -> u16 {
96 self.socket.port()
97 }
98}
99
100impl<T: SerializableId> Serialize for SocketInfo<T> {
101 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
102 where
103 S: Serializer,
104 {
105 let key = self.key();
106 let id = self.get_id();
107 let mut serial;
108 if key.len() == 0 {
109 serial = serializer.serialize_struct("SocketAddr", 2)?
110 } else {
111 serial = serializer.serialize_struct("SocketAddr", 3)?;
112 serial.serialize_field(key, &(id.expect("no id")).id())?;
113 };
114
115 let ip = self.ip();
116 if ip.is_ipv4() {
117 serial.serialize_field("ip_v4", &ip.to_string())?;
118 } else {
119 serial.serialize_field("ip_v6", &ip.to_string())?;
120 }
121 serial.serialize_field("port", &self.port())?;
122 serial.end()
123 }
124}
125
126impl<'de, X: SerializableId> Deserialize<'de> for SocketInfo<X> {
127 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128 where
129 D: Deserializer<'de>,
130 {
131 use std::marker::PhantomData;
132 enum Field {
133 IP,
134 PORT,
135 ID,
136 }
137
138 impl<'de> Deserialize<'de> for Field {
139 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
140 where
141 D: Deserializer<'de>,
142 {
143 struct FieldVisitor;
144
145 impl<'de> Visitor<'de> for FieldVisitor {
146 type Value = Field;
147
148 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
149 formatter.write_str("`ip_v4` or `ip_v6` or `port` or `*_id`")
150 }
151
152 fn visit_str<E>(self, value: &str) -> Result<Field, E>
153 where
154 E: de::Error,
155 {
156 match value {
157 "ip_v4" => Ok(Field::IP),
158 "ip_v6" => Ok(Field::IP),
159 "port" => Ok(Field::PORT),
160 id if id.ends_with("_id") => Ok(Field::ID),
161 _ => Err(de::Error::unknown_field(value, FIELDS)),
162 }
163 }
164 }
165
166 deserializer.deserialize_identifier(FieldVisitor)
167 }
168 }
169
170 struct SocketInfoVisitor<T>(PhantomData<T>);
171
172 impl<'de, T: SerializableId> Visitor<'de> for SocketInfoVisitor<T> {
173 type Value = SocketInfo<T>;
174
175 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176 formatter.write_str("struct SocketAddr")
177 }
178
179 fn visit_map<V>(self, mut map: V) -> Result<SocketInfo<T>, V::Error>
180 where
181 V: MapAccess<'de>,
182 {
183 let mut ip: Option<String> = None;
184 let mut id: Option<String> = None;
185 let mut port: Option<u16> = None;
186 while let Some(key) = map.next_key()? {
187 match key {
188 Field::PORT => {
189 if port.is_some() {
190 return Err(de::Error::duplicate_field("port"));
191 }
192 port = Some(map.next_value()?);
193 }
194 Field::IP => {
195 if ip.is_some() {
196 return Err(de::Error::duplicate_field("ip_v4"));
197 }
198 ip = Some(map.next_value()?);
199 }
200 Field::ID => {
201 if id.is_some() {
202 return Err(de::Error::duplicate_field("id"));
203 }
204 id = Some(map.next_value()?);
205 }
206 }
207 }
208 let ip = ip.ok_or_else(|| de::Error::missing_field("ip_v4 or ip_v6"))?;
209 let port = port.ok_or_else(|| de::Error::missing_field("port"))?;
210 let socket_info = SocketInfo::<T>::try_create(id, &ip, port);
211 if let Err(_err) = socket_info {
212 use serde::de::Error;
213 return Err(V::Error::custom(format!("fail to deserialize socket")));
214 }
215
216 Ok(socket_info.unwrap())
217 }
218 }
219
220 const FIELDS: &'static [&'static str] = &["ip_v4", "ip_v6", "port", "*_id"];
221 deserializer.deserialize_struct("SocketAddr", FIELDS, SocketInfoVisitor(PhantomData))
222 }
223}
224
225#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
231pub struct PhantomId(String);
232
233impl SerializableId for PhantomId {
234 fn try_create(id: impl Into<String>) -> Result<Self, error::Error>
235 where
236 Self: Sized,
237 {
238 Ok(PhantomId(id.into()))
239 }
240
241 fn as_str(&self) -> &str {
242 ""
243 }
244
245 fn id(&self) -> String {
246 String::from("")
247 }
248
249 fn key(&self) -> &'static str {
250 ""
251 }
252}
253
#[cfg(test)]
mod test_socket_info {
    use super::*;

    /// Builds an id-less `SocketInfo` for `ip` on port 8000, encodes it as
    /// JSON, decodes it again and checks that the value is unchanged.
    fn assert_round_trip(ip: &str) {
        let original = SocketInfo::<PhantomId>::try_create(None, ip, 8000).unwrap();
        let encoded = serde_json::to_string(&original).expect("serialize failed");
        let decoded: SocketInfo<PhantomId> =
            serde_json::from_str(&encoded).expect("deserialize failed");
        assert_eq!(original, decoded);
    }

    #[test]
    fn v4() {
        assert_round_trip("127.0.0.1");
    }

    #[test]
    fn v6() {
        assert_round_trip("2001:DB8:0:0:8:800:200C:417A");
    }
}