zenoh_plugin_rest/
config.rs1use std::fmt;
15
16use schemars::JsonSchema;
17use serde::{
18 de,
19 de::{Unexpected, Visitor},
20 Deserialize, Deserializer,
21};
22
/// Fallback for `Config::work_thread_num` when the field is omitted from the config.
pub const DEFAULT_WORK_THREAD_NUM: usize = 2;
/// Fallback for `Config::max_block_thread_num` when the field is omitted from the config.
pub const DEFAULT_MAX_BLOCK_THREAD_NUM: usize = 50;

/// Interface prepended when `http_port` only specifies a port number; `[::]` is the
/// bracketed IPv6 unspecified address.
const DEFAULT_HTTP_INTERFACE: &str = "[::]";
26
27#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
28#[serde(deny_unknown_fields)]
29pub struct Config {
30 #[serde(deserialize_with = "deserialize_http_port")]
31 pub http_port: String,
32 #[serde(default = "default_work_thread_num")]
33 pub work_thread_num: usize,
34 #[serde(default = "default_max_block_thread_num")]
35 pub max_block_thread_num: usize,
36 #[serde(default, deserialize_with = "deserialize_path")]
37 __path__: Option<Vec<String>>,
38 __required__: Option<bool>,
39 __config__: Option<String>,
40 __plugin__: Option<String>,
41}
42
43impl From<&Config> for serde_json::Value {
44 fn from(c: &Config) -> Self {
45 serde_json::to_value(c).unwrap()
46 }
47}
48
49fn deserialize_http_port<'de, D>(deserializer: D) -> Result<String, D::Error>
50where
51 D: Deserializer<'de>,
52{
53 deserializer.deserialize_any(HttpPortVisitor)
54}
55
56fn default_work_thread_num() -> usize {
57 DEFAULT_WORK_THREAD_NUM
58}
59
60fn default_max_block_thread_num() -> usize {
61 DEFAULT_MAX_BLOCK_THREAD_NUM
62}
63
64struct HttpPortVisitor;
65
66impl Visitor<'_> for HttpPortVisitor {
67 type Value = String;
68
69 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
70 formatter.write_str(r#"either a port number as an integer or a string, either a string with format "<local_ip>:<port_number>""#)
71 }
72
73 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
74 where
75 E: de::Error,
76 {
77 Ok(format!("{DEFAULT_HTTP_INTERFACE}:{value}"))
78 }
79
80 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
81 where
82 E: de::Error,
83 {
84 let parts: Vec<&str> = value.split(':').collect();
85 if parts.len() > 2 {
86 return Err(E::invalid_value(Unexpected::Str(value), &self));
87 }
88 let (interface, port) = if parts.len() == 1 {
89 (DEFAULT_HTTP_INTERFACE, parts[0])
90 } else {
91 (parts[0], parts[1])
92 };
93 if port.parse::<u32>().is_err() {
94 return Err(E::invalid_value(Unexpected::Str(port), &self));
95 }
96 Ok(format!("{interface}:{port}"))
97 }
98}
99
100fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
101where
102 D: Deserializer<'de>,
103{
104 deserializer.deserialize_option(OptPathVisitor)
105}
106
107struct OptPathVisitor;
108
109impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
110 type Value = Option<Vec<String>>;
111
112 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
113 write!(formatter, "none or a string or an array of strings")
114 }
115
116 fn visit_none<E>(self) -> Result<Self::Value, E>
117 where
118 E: de::Error,
119 {
120 Ok(None)
121 }
122
123 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
124 where
125 D: Deserializer<'de>,
126 {
127 deserializer.deserialize_any(PathVisitor).map(Some)
128 }
129}
130
131struct PathVisitor;
132
133impl<'de> serde::de::Visitor<'de> for PathVisitor {
134 type Value = Vec<String>;
135
136 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
137 write!(formatter, "a string or an array of strings")
138 }
139
140 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
141 where
142 E: de::Error,
143 {
144 Ok(vec![v.into()])
145 }
146
147 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
148 where
149 A: de::SeqAccess<'de>,
150 {
151 let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
152
153 while let Some(s) = seq.next_element()? {
154 v.push(s);
155 }
156 Ok(v)
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::{
        Config, DEFAULT_HTTP_INTERFACE, DEFAULT_MAX_BLOCK_THREAD_NUM, DEFAULT_WORK_THREAD_NUM,
    };

    /// Deserializes `json` into a [`Config`], panicking with the input and the error
    /// if it is rejected.
    fn parse(json: &str) -> Config {
        serde_json::from_str::<Config>(json)
            .unwrap_or_else(|e| panic!("failed to parse {json}: {e}"))
    }

    /// Returns `true` when `json` is rejected as a [`Config`].
    fn rejects(json: &str) -> bool {
        serde_json::from_str::<Config>(json).is_err()
    }

    #[test]
    fn test_path_field() {
        let Config {
            http_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path", "http_port": 8080}"#);

        assert_eq!(http_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, Some(vec![String::from("/example/path")]));
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_required_field() {
        let Config {
            http_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__required__": true, "http_port": 8080}"#);

        assert_eq!(http_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, None);
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        let Config {
            http_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path", "__required__": true, "http_port": 8080}"#);

        assert_eq!(http_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, Some(vec![String::from("/example/path")]));
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        let Config {
            http_port,
            __required__,
            __path__,
            ..
        } = parse(r#"{"http_port": 8080}"#);

        assert_eq!(http_port, format!("{DEFAULT_HTTP_INTERFACE}:8080"));
        assert_eq!(__path__, None);
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_path_field_as_array_or_null() {
        let config = parse(r#"{"__path__": ["/a", "/b"], "http_port": 8080}"#);
        assert_eq!(
            config.__path__,
            Some(vec![String::from("/a"), String::from("/b")])
        );

        let config = parse(r#"{"__path__": null, "http_port": 8080}"#);
        assert_eq!(config.__path__, None);
    }

    #[test]
    fn test_http_port_string_forms() {
        assert_eq!(
            parse(r#"{"http_port": "8080"}"#).http_port,
            format!("{DEFAULT_HTTP_INTERFACE}:8080")
        );
        assert_eq!(
            parse(r#"{"http_port": "127.0.0.1:8080"}"#).http_port,
            "127.0.0.1:8080"
        );
    }

    #[test]
    fn test_http_port_rejects_malformed_values() {
        // `http_port` has no default, so it is mandatory.
        assert!(rejects(r#"{}"#));
        assert!(rejects(r#"{"http_port": ""}"#));
        assert!(rejects(r#"{"http_port": "abc"}"#));
        assert!(rejects(r#"{"http_port": "127.0.0.1:abc"}"#));
        assert!(rejects(r#"{"http_port": "a:b:c"}"#));
        assert!(rejects(r#"{"http_port": -1}"#));
    }

    #[test]
    fn test_thread_num_defaults_and_overrides() {
        let defaults = parse(r#"{"http_port": 8080}"#);
        assert_eq!(defaults.work_thread_num, DEFAULT_WORK_THREAD_NUM);
        assert_eq!(defaults.max_block_thread_num, DEFAULT_MAX_BLOCK_THREAD_NUM);

        let custom =
            parse(r#"{"http_port": 8080, "work_thread_num": 4, "max_block_thread_num": 8}"#);
        assert_eq!(custom.work_thread_num, 4);
        assert_eq!(custom.max_block_thread_num, 8);
    }

    #[test]
    fn test_unknown_field_rejected() {
        assert!(rejects(r#"{"http_port": 8080, "unknown": 1}"#));
    }
}