1use std::fmt;
15
16use schemars::JsonSchema;
17use serde::{
18 de,
19 de::{Unexpected, Visitor},
20 Deserialize, Deserializer,
21};
22
/// Interface used when a websocket address specifies only a port (integer or bare
/// port string), and for the default `websocket_port`. `[::]` is the bracketed IPv6
/// unspecified address; note that it contains `:` characters itself.
const DEFAULT_HTTP_INTERFACE: &str = "[::]";
/// Port used for `websocket_port` when the key is missing from the configuration.
const DEFAULT_WEBSOCKET_PORT: &str = "10000";
25
// Websocket server configuration. Unknown keys are rejected (`deny_unknown_fields`).
//
// NOTE(review): plain `//` comments are used on purpose; `///` doc comments would
// presumably also be emitted as descriptions by the `JsonSchema` derive and change the
// generated schema — confirm before switching.
#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct Config {
    // Despite its name, this holds a full "<interface>:<port>" string, not just a port.
    // When the key is absent it falls back to `default_websocket_port()`; when present
    // it is parsed by `deserialize_ws_port`, which accepts an integer or a string.
    #[serde(
        default = "default_websocket_port",
        deserialize_with = "deserialize_ws_port"
    )]
    pub websocket_port: String,

    // Optional certificate/private-key pair.
    // NOTE(review): presumably switches the server to TLS (wss) when set — confirm.
    pub secure_websocket: Option<SecureWebsocket>,

    // Accepts a single string or an array of strings (via `deserialize_path`) and is
    // always stored as a list; `#[serde(default)]` makes a missing key `None`.
    // NOTE(review): the `__path__` / `__required__` / `__config__` keys look like
    // metadata injected by a config loader rather than user settings — confirm.
    #[serde(default, deserialize_with = "deserialize_path")]
    __path__: Option<Vec<String>>,
    // Plain `Option` fields without `deserialize_with`: a missing key becomes `None`.
    __required__: Option<bool>,
    __config__: Option<String>,
}
42
// File paths of the certificate and private key for the websocket server. Both keys are
// required and unknown keys are rejected (`deny_unknown_fields`).
// NOTE(review): presumably PEM files used to serve wss:// — confirm with the TLS setup.
// Plain `//` comments avoid feeding doc text into the `JsonSchema` derive.
#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct SecureWebsocket {
    pub certificate_path: String,
    pub private_key_path: String,
}
49
50impl From<&Config> for serde_json::Value {
51 fn from(c: &Config) -> Self {
52 serde_json::to_value(c).unwrap()
53 }
54}
55
56fn default_websocket_port() -> String {
57 format!("{}:{}", DEFAULT_HTTP_INTERFACE, DEFAULT_WEBSOCKET_PORT)
58}
59
60fn deserialize_ws_port<'de, D>(deserializer: D) -> Result<String, D::Error>
61where
62 D: Deserializer<'de>,
63{
64 deserializer.deserialize_any(WebsocketVisitor)
65}
66
67struct WebsocketVisitor;
68
69impl Visitor<'_> for WebsocketVisitor {
70 type Value = String;
71
72 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
73 formatter.write_str(r#"either a port number as an integer or a string, either a string with format "<local_ip>:<port_number>""#)
74 }
75
76 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
77 where
78 E: de::Error,
79 {
80 Ok(format!("{DEFAULT_HTTP_INTERFACE}:{value}"))
81 }
82
83 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
84 where
85 E: de::Error,
86 {
87 let parts: Vec<&str> = value.split(':').collect();
88 if parts.len() > 2 {
89 return Err(E::invalid_value(Unexpected::Str(value), &self));
90 }
91 let (interface, port) = if parts.len() == 1 {
92 (DEFAULT_HTTP_INTERFACE, parts[0])
93 } else {
94 (parts[0], parts[1])
95 };
96 if port.parse::<u32>().is_err() {
97 return Err(E::invalid_value(Unexpected::Str(port), &self));
98 }
99 Ok(format!("{interface}:{port}"))
100 }
101}
102
103fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
104where
105 D: Deserializer<'de>,
106{
107 deserializer.deserialize_option(OptPathVisitor)
108}
109
110struct OptPathVisitor;
111
112impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
113 type Value = Option<Vec<String>>;
114
115 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
116 write!(formatter, "none or a string or an array of strings")
117 }
118
119 fn visit_none<E>(self) -> Result<Self::Value, E>
120 where
121 E: de::Error,
122 {
123 Ok(None)
124 }
125
126 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
127 where
128 D: Deserializer<'de>,
129 {
130 deserializer.deserialize_any(PathVisitor).map(Some)
131 }
132}
133
134struct PathVisitor;
135
136impl<'de> serde::de::Visitor<'de> for PathVisitor {
137 type Value = Vec<String>;
138
139 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
140 write!(formatter, "a string or an array of strings")
141 }
142
143 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
144 where
145 E: de::Error,
146 {
147 Ok(vec![v.into()])
148 }
149
150 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
151 where
152 A: de::SeqAccess<'de>,
153 {
154 let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
155
156 while let Some(s) = seq.next_element()? {
157 v.push(s);
158 }
159 Ok(v)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::{Config, DEFAULT_HTTP_INTERFACE, DEFAULT_WEBSOCKET_PORT};

    /// Deserializes `json` into a [`Config`], failing the test (with serde's error) if it
    /// does not parse.
    fn parse(json: &str) -> Config {
        serde_json::from_str(json).expect("test JSON should deserialize into Config")
    }

    #[test]
    fn test_path_field() {
        let Config {
            websocket_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path", "websocket_port": 8080}"#);

        assert_eq!(websocket_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, Some(vec![String::from("/example/path")]));
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_required_field() {
        let Config {
            websocket_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__required__": true, "websocket_port": 8080}"#);

        assert_eq!(websocket_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, None);
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        let Config {
            websocket_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path", "__required__": true, "websocket_port": 8080}"#);

        assert_eq!(websocket_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, Some(vec![String::from("/example/path")]));
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        let Config {
            websocket_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"websocket_port": 8080}"#);

        assert_eq!(websocket_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, None);
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_default_websocket_port() {
        let Config {
            websocket_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{}"#);

        assert_eq!(
            websocket_port,
            format!("{DEFAULT_HTTP_INTERFACE}:{DEFAULT_WEBSOCKET_PORT}")
        );
        assert_eq!(__path__, None);
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_path_field_as_array() {
        let Config { __path__, .. } = parse(r#"{"__path__": ["/a", "/b"]}"#);

        assert_eq!(__path__, Some(vec![String::from("/a"), String::from("/b")]));
    }

    #[test]
    fn test_null_path_field() {
        let Config { __path__, .. } = parse(r#"{"__path__": null}"#);

        assert_eq!(__path__, None);
    }

    #[test]
    fn test_websocket_port_as_string() {
        assert_eq!(
            parse(r#"{"websocket_port": "9000"}"#).websocket_port,
            format!("{DEFAULT_HTTP_INTERFACE}:9000")
        );
        assert_eq!(
            parse(r#"{"websocket_port": "127.0.0.1:9000"}"#).websocket_port,
            "127.0.0.1:9000"
        );
    }

    #[test]
    fn test_invalid_websocket_port_is_rejected() {
        assert!(serde_json::from_str::<Config>(r#"{"websocket_port": "not-a-port"}"#).is_err());
    }

    #[test]
    fn test_unknown_field_is_rejected() {
        assert!(serde_json::from_str::<Config>(r#"{"unknown": 1}"#).is_err());
    }
}