zenoh_plugin_dds/
config.rs

//
// Copyright (c) 2022 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
14use std::{env, fmt, time::Duration};
15
16use regex::Regex;
17use serde::{de, de::Visitor, Deserialize, Deserializer};
18use zenoh::key_expr::OwnedKeyExpr;
19
/// DDS domain ID used when `ROS_DOMAIN_ID` is unset or not a valid `u32`.
pub const DEFAULT_DOMAIN: u32 = 0;
/// Default value of the `forward_discovery` option.
pub const DEFAULT_FORWARD_DISCOVERY: bool = false;
/// Default value of the `reliable_routes_blocking` option.
pub const DEFAULT_RELIABLE_ROUTES_BLOCKING: bool = true;
/// Default queries timeout, in seconds.
pub const DEFAULT_QUERIES_TIMEOUT: f32 = 5.0;
/// Default value of the `localhost_only` option.
pub const DEFAULT_DDS_LOCALHOST_ONLY: bool = false;
/// Default value of the `work_thread_num` option.
pub const DEFAULT_WORK_THREAD_NUM: usize = 2;
/// Default value of the `max_block_thread_num` option.
pub const DEFAULT_MAX_BLOCK_THREAD_NUM: usize = 50;
27
28#[derive(Deserialize, Debug)]
29#[serde(deny_unknown_fields)]
30pub struct Config {
31    #[serde(default)]
32    pub scope: Option<OwnedKeyExpr>,
33    #[serde(default = "default_domain")]
34    pub domain: u32,
35    #[serde(default, deserialize_with = "deserialize_regex")]
36    pub allow: Option<Regex>,
37    #[serde(default, deserialize_with = "deserialize_regex")]
38    pub deny: Option<Regex>,
39    #[serde(default, deserialize_with = "deserialize_max_frequencies")]
40    pub max_frequencies: Vec<(Regex, f32)>,
41    #[serde(default)]
42    pub generalise_subs: Vec<OwnedKeyExpr>,
43    #[serde(default)]
44    pub generalise_pubs: Vec<OwnedKeyExpr>,
45    #[serde(default = "default_forward_discovery")]
46    pub forward_discovery: bool,
47    #[serde(default = "default_reliable_routes_blocking")]
48    pub reliable_routes_blocking: bool,
49    #[serde(default = "default_localhost_only")]
50    pub localhost_only: bool,
51    #[serde(default)]
52    #[cfg(feature = "dds_shm")]
53    pub shm_enabled: bool,
54    #[serde(
55        default = "default_queries_timeout",
56        deserialize_with = "deserialize_duration"
57    )]
58    pub queries_timeout: Duration,
59    #[serde(default = "default_work_thread_num")]
60    pub work_thread_num: usize,
61    #[serde(default = "default_max_block_thread_num")]
62    pub max_block_thread_num: usize,
63    __required__: Option<bool>,
64    #[serde(default, deserialize_with = "deserialize_path")]
65    __path__: Option<Vec<String>>,
66}
67
68fn default_domain() -> u32 {
69    if let Ok(s) = env::var("ROS_DOMAIN_ID") {
70        s.parse::<u32>().unwrap_or(DEFAULT_DOMAIN)
71    } else {
72        DEFAULT_DOMAIN
73    }
74}
75
76fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
77where
78    D: Deserializer<'de>,
79{
80    deserializer.deserialize_option(OptPathVisitor)
81}
82
83struct OptPathVisitor;
84
85impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
86    type Value = Option<Vec<String>>;
87
88    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89        write!(formatter, "none or a string or an array of strings")
90    }
91
92    fn visit_none<E>(self) -> Result<Self::Value, E>
93    where
94        E: de::Error,
95    {
96        Ok(None)
97    }
98
99    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
100    where
101        D: Deserializer<'de>,
102    {
103        deserializer.deserialize_any(PathVisitor).map(Some)
104    }
105}
106
107struct PathVisitor;
108
109impl<'de> serde::de::Visitor<'de> for PathVisitor {
110    type Value = Vec<String>;
111
112    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
113        write!(formatter, "a string or an array of strings")
114    }
115
116    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
117    where
118        E: de::Error,
119    {
120        Ok(vec![v.into()])
121    }
122
123    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
124    where
125        A: de::SeqAccess<'de>,
126    {
127        let mut v = if let Some(l) = seq.size_hint() {
128            Vec::with_capacity(l)
129        } else {
130            Vec::new()
131        };
132        while let Some(s) = seq.next_element()? {
133            v.push(s);
134        }
135        Ok(v)
136    }
137}
138
139fn deserialize_regex<'de, D>(deserializer: D) -> Result<Option<Regex>, D::Error>
140where
141    D: Deserializer<'de>,
142{
143    deserializer.deserialize_any(RegexVisitor)
144}
145
146fn deserialize_max_frequencies<'de, D>(deserializer: D) -> Result<Vec<(Regex, f32)>, D::Error>
147where
148    D: Deserializer<'de>,
149{
150    let strs: Vec<String> = Deserialize::deserialize(deserializer)?;
151    let mut result: Vec<(Regex, f32)> = Vec::with_capacity(strs.len());
152    for s in strs {
153        let i = s
154            .find('=')
155            .ok_or_else(|| de::Error::custom(format!("Invalid 'max_frequency': {s}")))?;
156        let regex = Regex::new(&s[0..i]).map_err(|e| {
157            de::Error::custom(format!("Invalid regex for 'max_frequency': '{s}': {e}"))
158        })?;
159        let frequency: f32 = s[i + 1..].parse().map_err(|e| {
160            de::Error::custom(format!(
161                "Invalid float value for 'max_frequency': '{s}': {e}"
162            ))
163        })?;
164        result.push((regex, frequency));
165    }
166    Ok(result)
167}
168
169fn default_queries_timeout() -> Duration {
170    Duration::from_secs_f32(DEFAULT_QUERIES_TIMEOUT)
171}
172
173fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
174where
175    D: Deserializer<'de>,
176{
177    let seconds: f32 = Deserialize::deserialize(deserializer)?;
178    Ok(Duration::from_secs_f32(seconds))
179}
180
181fn default_work_thread_num() -> usize {
182    DEFAULT_WORK_THREAD_NUM
183}
184
185fn default_max_block_thread_num() -> usize {
186    DEFAULT_MAX_BLOCK_THREAD_NUM
187}
188
189fn default_forward_discovery() -> bool {
190    DEFAULT_FORWARD_DISCOVERY
191}
192
193fn default_reliable_routes_blocking() -> bool {
194    DEFAULT_RELIABLE_ROUTES_BLOCKING
195}
196
197fn default_localhost_only() -> bool {
198    env::var("ROS_LOCALHOST_ONLY").as_deref() == Ok("1")
199}
200
201// Serde Visitor for Regex deserialization.
202// It accepts either a String, either a list of Strings (that are concatenated with `|`)
203struct RegexVisitor;
204
205impl<'de> Visitor<'de> for RegexVisitor {
206    type Value = Option<Regex>;
207
208    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
209        formatter.write_str(r#"either a string or a list of strings"#)
210    }
211
212    // for `null` value
213    fn visit_unit<E>(self) -> Result<Self::Value, E>
214    where
215        E: de::Error,
216    {
217        Ok(None)
218    }
219
220    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
221    where
222        E: de::Error,
223    {
224        Regex::new(value)
225            .map(Some)
226            .map_err(|e| de::Error::custom(format!("Invalid regex '{value}': {e}")))
227    }
228
229    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
230    where
231        A: de::SeqAccess<'de>,
232    {
233        let mut vec: Vec<String> = Vec::new();
234        while let Some(s) = seq.next_element()? {
235            vec.push(s);
236        }
237        let s: String = vec.join("|");
238        Regex::new(&s)
239            .map(Some)
240            .map_err(|e| de::Error::custom(format!("Invalid regex '{s}': {e}")))
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::Config;

    // All cases below cover the reserved `__path__` / `__required__` fields.
    // See: https://github.com/eclipse-zenoh/zenoh-plugin-webserver/issues/19

    /// Parses `json` as a [`Config`] and returns its `(__path__, __required__)` fields,
    /// failing the test with the deserialization error if parsing fails.
    fn reserved_fields(json: &str) -> (Option<Vec<String>>, Option<bool>) {
        let config: Config = serde_json::from_str(json)
            .unwrap_or_else(|err| panic!("config {json} should parse, got: {err}"));
        let Config {
            __path__: path,
            __required__: required,
            ..
        } = config;
        (path, required)
    }

    #[test]
    fn test_path_field() {
        let (path, required) = reserved_fields(r#"{"__path__": "/example/path"}"#);
        assert_eq!(path, Some(vec![String::from("/example/path")]));
        assert_eq!(required, None);
    }

    #[test]
    fn test_required_field() {
        let (path, required) = reserved_fields(r#"{"__required__": true}"#);
        assert_eq!(path, None);
        assert_eq!(required, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        let (path, required) =
            reserved_fields(r#"{"__path__": "/example/path", "__required__": true}"#);
        assert_eq!(path, Some(vec![String::from("/example/path")]));
        assert_eq!(required, Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        let (path, required) = reserved_fields("{}");
        assert_eq!(path, None);
        assert_eq!(required, None);
    }
}