1use std::fmt;
15
16use schemars::JsonSchema;
17use serde::{
18 de,
19 de::{Unexpected, Visitor},
20 Deserialize, Deserializer,
21};
22
/// Interface prepended when `websocket_port` only specifies a port number.
///
/// `[::]` is the IPv6 unspecified address written in bracketed form, so that
/// the result can be joined with `":<port>"` unambiguously.
const DEFAULT_HTTP_INTERFACE: &str = "[::]";
24
25#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
26#[serde(deny_unknown_fields)]
27pub struct Config {
28 #[serde(deserialize_with = "deserialize_ws_port")]
29 pub websocket_port: String,
30
31 pub secure_websocket: Option<SecureWebsocket>,
32
33 #[serde(default, deserialize_with = "deserialize_path")]
34 __path__: Option<Vec<String>>,
35 __required__: Option<bool>,
36 __config__: Option<String>,
37}
38
39#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
40#[serde(deny_unknown_fields)]
41pub struct SecureWebsocket {
42 pub certificate_path: String,
43 pub private_key_path: String,
44}
45
46impl From<&Config> for serde_json::Value {
47 fn from(c: &Config) -> Self {
48 serde_json::to_value(c).unwrap()
49 }
50}
51
52fn deserialize_ws_port<'de, D>(deserializer: D) -> Result<String, D::Error>
53where
54 D: Deserializer<'de>,
55{
56 deserializer.deserialize_any(WebsocketVisitor)
57}
58
59struct WebsocketVisitor;
60
61impl Visitor<'_> for WebsocketVisitor {
62 type Value = String;
63
64 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
65 formatter.write_str(r#"either a port number as an integer or a string, either a string with format "<local_ip>:<port_number>""#)
66 }
67
68 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
69 where
70 E: de::Error,
71 {
72 Ok(format!("{DEFAULT_HTTP_INTERFACE}:{value}"))
73 }
74
75 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
76 where
77 E: de::Error,
78 {
79 let parts: Vec<&str> = value.split(':').collect();
80 if parts.len() > 2 {
81 return Err(E::invalid_value(Unexpected::Str(value), &self));
82 }
83 let (interface, port) = if parts.len() == 1 {
84 (DEFAULT_HTTP_INTERFACE, parts[0])
85 } else {
86 (parts[0], parts[1])
87 };
88 if port.parse::<u32>().is_err() {
89 return Err(E::invalid_value(Unexpected::Str(port), &self));
90 }
91 Ok(format!("{interface}:{port}"))
92 }
93}
94
95fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
96where
97 D: Deserializer<'de>,
98{
99 deserializer.deserialize_option(OptPathVisitor)
100}
101
102struct OptPathVisitor;
103
104impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
105 type Value = Option<Vec<String>>;
106
107 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
108 write!(formatter, "none or a string or an array of strings")
109 }
110
111 fn visit_none<E>(self) -> Result<Self::Value, E>
112 where
113 E: de::Error,
114 {
115 Ok(None)
116 }
117
118 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
119 where
120 D: Deserializer<'de>,
121 {
122 deserializer.deserialize_any(PathVisitor).map(Some)
123 }
124}
125
126struct PathVisitor;
127
128impl<'de> serde::de::Visitor<'de> for PathVisitor {
129 type Value = Vec<String>;
130
131 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
132 write!(formatter, "a string or an array of strings")
133 }
134
135 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
136 where
137 E: de::Error,
138 {
139 Ok(vec![v.into()])
140 }
141
142 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
143 where
144 A: de::SeqAccess<'de>,
145 {
146 let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
147
148 while let Some(s) = seq.next_element()? {
149 v.push(s);
150 }
151 Ok(v)
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::{Config, DEFAULT_HTTP_INTERFACE};

    /// Parses `json` as a [`Config`], failing the test if it is rejected.
    fn parse(json: &str) -> Config {
        serde_json::from_str::<Config>(json)
            .unwrap_or_else(|err| panic!("config {json} should parse: {err}"))
    }

    /// Checks the fields every test inspects. The port is always 8080 on the
    /// default interface; `__path__` and `__required__` vary per case.
    fn check(config: Config, expected_path: Option<Vec<String>>, expected_required: Option<bool>) {
        let Config {
            websocket_port,
            __path__,
            __required__,
            ..
        } = config;
        assert_eq!(websocket_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, expected_path);
        assert_eq!(__required__, expected_required);
    }

    /// The single path used by the cases that set `__path__`.
    fn example_path() -> Option<Vec<String>> {
        Some(vec![String::from("/example/path")])
    }

    #[test]
    fn test_path_field() {
        let config = parse(r#"{"__path__": "/example/path", "websocket_port": 8080}"#);
        check(config, example_path(), None);
    }

    #[test]
    fn test_required_field() {
        let config = parse(r#"{"__required__": true, "websocket_port": 8080}"#);
        check(config, None, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        let config = parse(
            r#"{"__path__": "/example/path", "__required__": true, "websocket_port": 8080}"#,
        );
        check(config, example_path(), Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        let config = parse(r#"{"websocket_port": 8080}"#);
        check(config, None, None);
    }
}