zenoh_plugin_rest/
config.rs

1//
2// Copyright (c) 2023 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14use std::fmt;
15
16use schemars::JsonSchema;
17use serde::{
18    de,
19    de::{Unexpected, Visitor},
20    Deserialize, Deserializer,
21};
22
/// Default value of `Config::work_thread_num` when the key is omitted.
pub const DEFAULT_WORK_THREAD_NUM: usize = 2;
/// Default value of `Config::max_block_thread_num` when the key is omitted.
pub const DEFAULT_MAX_BLOCK_THREAD_NUM: usize = 50;

/// Interface prepended to `http_port` when only a port number is configured
/// (`[::]` is the bracketed IPv6 unspecified address).
const DEFAULT_HTTP_INTERFACE: &str = "[::]";
26
// Configuration of the REST plugin, deserialized from the plugin's config section.
// `deny_unknown_fields` turns any key not declared below into a deserialization error.
// NOTE(review): plain `//` comments are used on purpose: `///` doc comments would be
// picked up by the `JsonSchema` derive as schema descriptions.
#[derive(JsonSchema, Deserialize, serde::Serialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct Config {
    // Accepts an integer port or a string ("<port>" or "<interface>:<port>");
    // `deserialize_http_port` normalises it to a "<interface>:<port>" string, using
    // DEFAULT_HTTP_INTERFACE when no interface is given.
    #[serde(deserialize_with = "deserialize_http_port")]
    pub http_port: String,
    // Falls back to DEFAULT_WORK_THREAD_NUM when the key is omitted.
    // NOTE(review): presumably a runtime worker-thread count — confirm with its consumer.
    #[serde(default = "default_work_thread_num")]
    pub work_thread_num: usize,
    // Falls back to DEFAULT_MAX_BLOCK_THREAD_NUM when the key is omitted.
    // NOTE(review): presumably a cap on blocking threads — confirm with its consumer.
    #[serde(default = "default_max_block_thread_num")]
    pub max_block_thread_num: usize,
    // The `__*__` fields are declared so that `deny_unknown_fields` accepts these keys;
    // presumably they are injected by the zenoh plugin loader — TODO confirm.
    // `__path__` may be absent/null, a single string, or an array of strings (see
    // `deserialize_path`). `default` is needed because, with `deserialize_with`, serde
    // no longer treats a missing `Option` field as `None` on its own.
    #[serde(default, deserialize_with = "deserialize_path")]
    __path__: Option<Vec<String>>,
    __required__: Option<bool>,
    __config__: Option<String>,
    __plugin__: Option<String>,
}
42
43impl From<&Config> for serde_json::Value {
44    fn from(c: &Config) -> Self {
45        serde_json::to_value(c).unwrap()
46    }
47}
48
49fn deserialize_http_port<'de, D>(deserializer: D) -> Result<String, D::Error>
50where
51    D: Deserializer<'de>,
52{
53    deserializer.deserialize_any(HttpPortVisitor)
54}
55
56fn default_work_thread_num() -> usize {
57    DEFAULT_WORK_THREAD_NUM
58}
59
60fn default_max_block_thread_num() -> usize {
61    DEFAULT_MAX_BLOCK_THREAD_NUM
62}
63
64struct HttpPortVisitor;
65
66impl Visitor<'_> for HttpPortVisitor {
67    type Value = String;
68
69    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
70        formatter.write_str(r#"either a port number as an integer or a string, either a string with format "<local_ip>:<port_number>""#)
71    }
72
73    fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
74    where
75        E: de::Error,
76    {
77        Ok(format!("{DEFAULT_HTTP_INTERFACE}:{value}"))
78    }
79
80    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
81    where
82        E: de::Error,
83    {
84        let parts: Vec<&str> = value.split(':').collect();
85        if parts.len() > 2 {
86            return Err(E::invalid_value(Unexpected::Str(value), &self));
87        }
88        let (interface, port) = if parts.len() == 1 {
89            (DEFAULT_HTTP_INTERFACE, parts[0])
90        } else {
91            (parts[0], parts[1])
92        };
93        if port.parse::<u32>().is_err() {
94            return Err(E::invalid_value(Unexpected::Str(port), &self));
95        }
96        Ok(format!("{interface}:{port}"))
97    }
98}
99
100fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
101where
102    D: Deserializer<'de>,
103{
104    deserializer.deserialize_option(OptPathVisitor)
105}
106
107struct OptPathVisitor;
108
109impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
110    type Value = Option<Vec<String>>;
111
112    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
113        write!(formatter, "none or a string or an array of strings")
114    }
115
116    fn visit_none<E>(self) -> Result<Self::Value, E>
117    where
118        E: de::Error,
119    {
120        Ok(None)
121    }
122
123    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
124    where
125        D: Deserializer<'de>,
126    {
127        deserializer.deserialize_any(PathVisitor).map(Some)
128    }
129}
130
131struct PathVisitor;
132
133impl<'de> serde::de::Visitor<'de> for PathVisitor {
134    type Value = Vec<String>;
135
136    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
137        write!(formatter, "a string or an array of strings")
138    }
139
140    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
141    where
142        E: de::Error,
143    {
144        Ok(vec![v.into()])
145    }
146
147    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
148    where
149        A: de::SeqAccess<'de>,
150    {
151        let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity);
152
153        while let Some(s) = seq.next_element()? {
154            v.push(s);
155        }
156        Ok(v)
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::{Config, DEFAULT_HTTP_INTERFACE};

    /// Deserializes `json` into a [`Config`], failing the test with the serde error
    /// if the input is rejected.
    fn parse_config(json: &str) -> Config {
        serde_json::from_str::<Config>(json).expect("config should deserialize")
    }

    /// The normalised `http_port` expected for `"http_port": 8080`.
    fn default_port_8080() -> String {
        format!("{DEFAULT_HTTP_INTERFACE}:8080")
    }

    #[test]
    fn test_path_field() {
        // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19
        let config = parse_config(r#"{"__path__": "/example/path", "http_port": 8080}"#);

        assert_eq!(config.http_port, default_port_8080());
        assert_eq!(config.__path__, Some(vec!["/example/path".to_owned()]));
        assert!(config.__required__.is_none());
    }

    #[test]
    fn test_required_field() {
        // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19
        let config = parse_config(r#"{"__required__": true, "http_port": 8080}"#);

        assert_eq!(config.http_port, default_port_8080());
        assert!(config.__path__.is_none());
        assert_eq!(config.__required__, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19
        let config = parse_config(
            r#"{"__path__": "/example/path", "__required__": true, "http_port": 8080}"#,
        );

        assert_eq!(config.http_port, default_port_8080());
        assert_eq!(config.__path__, Some(vec!["/example/path".to_owned()]));
        assert_eq!(config.__required__, Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19
        let config = parse_config(r#"{"http_port": 8080}"#);

        assert_eq!(config.http_port, default_port_8080());
        assert!(config.__path__.is_none());
        assert!(config.__required__.is_none());
    }
}