zenoh_plugin_dds/
config.rs1use std::{env, fmt, time::Duration};
15
16use regex::Regex;
17use serde::{de, de::Visitor, Deserialize, Deserializer};
18use zenoh::key_expr::OwnedKeyExpr;
19
// ---- Network / discovery scope -------------------------------------------

/// DDS domain id used when `ROS_DOMAIN_ID` is unset or is not a valid `u32`.
pub const DEFAULT_DOMAIN: u32 = 0;
/// Baseline value for the `localhost_only` option (the effective serde default
/// for that field also depends on the `ROS_LOCALHOST_ONLY` environment variable).
pub const DEFAULT_DDS_LOCALHOST_ONLY: bool = false;

// ---- Routing behaviour ---------------------------------------------------

/// Default for the `forward_discovery` option.
pub const DEFAULT_FORWARD_DISCOVERY: bool = false;
/// Default for the `reliable_routes_blocking` option.
pub const DEFAULT_RELIABLE_ROUTES_BLOCKING: bool = true;

// ---- Timing --------------------------------------------------------------

/// Default for the `queries_timeout` option, expressed in seconds.
pub const DEFAULT_QUERIES_TIMEOUT: f32 = 5.0_f32;

// ---- Threading -----------------------------------------------------------

/// Default for the `work_thread_num` option.
pub const DEFAULT_WORK_THREAD_NUM: usize = 2;
/// Default for the `max_block_thread_num` option.
pub const DEFAULT_MAX_BLOCK_THREAD_NUM: usize = 50;
27
28#[derive(Deserialize, Debug)]
29#[serde(deny_unknown_fields)]
30pub struct Config {
31 #[serde(default)]
32 pub scope: Option<OwnedKeyExpr>,
33 #[serde(default = "default_domain")]
34 pub domain: u32,
35 #[serde(default, deserialize_with = "deserialize_regex")]
36 pub allow: Option<Regex>,
37 #[serde(default, deserialize_with = "deserialize_regex")]
38 pub deny: Option<Regex>,
39 #[serde(default, deserialize_with = "deserialize_max_frequencies")]
40 pub max_frequencies: Vec<(Regex, f32)>,
41 #[serde(default)]
42 pub generalise_subs: Vec<OwnedKeyExpr>,
43 #[serde(default)]
44 pub generalise_pubs: Vec<OwnedKeyExpr>,
45 #[serde(default = "default_forward_discovery")]
46 pub forward_discovery: bool,
47 #[serde(default = "default_reliable_routes_blocking")]
48 pub reliable_routes_blocking: bool,
49 #[serde(default = "default_localhost_only")]
50 pub localhost_only: bool,
51 #[serde(default)]
52 #[cfg(feature = "dds_shm")]
53 pub shm_enabled: bool,
54 #[serde(
55 default = "default_queries_timeout",
56 deserialize_with = "deserialize_duration"
57 )]
58 pub queries_timeout: Duration,
59 #[serde(default = "default_work_thread_num")]
60 pub work_thread_num: usize,
61 #[serde(default = "default_max_block_thread_num")]
62 pub max_block_thread_num: usize,
63 __required__: Option<bool>,
64 #[serde(default, deserialize_with = "deserialize_path")]
65 __path__: Option<Vec<String>>,
66}
67
68fn default_domain() -> u32 {
69 if let Ok(s) = env::var("ROS_DOMAIN_ID") {
70 s.parse::<u32>().unwrap_or(DEFAULT_DOMAIN)
71 } else {
72 DEFAULT_DOMAIN
73 }
74}
75
76fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
77where
78 D: Deserializer<'de>,
79{
80 deserializer.deserialize_option(OptPathVisitor)
81}
82
83struct OptPathVisitor;
84
85impl<'de> serde::de::Visitor<'de> for OptPathVisitor {
86 type Value = Option<Vec<String>>;
87
88 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
89 write!(formatter, "none or a string or an array of strings")
90 }
91
92 fn visit_none<E>(self) -> Result<Self::Value, E>
93 where
94 E: de::Error,
95 {
96 Ok(None)
97 }
98
99 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
100 where
101 D: Deserializer<'de>,
102 {
103 deserializer.deserialize_any(PathVisitor).map(Some)
104 }
105}
106
107struct PathVisitor;
108
109impl<'de> serde::de::Visitor<'de> for PathVisitor {
110 type Value = Vec<String>;
111
112 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
113 write!(formatter, "a string or an array of strings")
114 }
115
116 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
117 where
118 E: de::Error,
119 {
120 Ok(vec![v.into()])
121 }
122
123 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
124 where
125 A: de::SeqAccess<'de>,
126 {
127 let mut v = if let Some(l) = seq.size_hint() {
128 Vec::with_capacity(l)
129 } else {
130 Vec::new()
131 };
132 while let Some(s) = seq.next_element()? {
133 v.push(s);
134 }
135 Ok(v)
136 }
137}
138
139fn deserialize_regex<'de, D>(deserializer: D) -> Result<Option<Regex>, D::Error>
140where
141 D: Deserializer<'de>,
142{
143 deserializer.deserialize_any(RegexVisitor)
144}
145
146fn deserialize_max_frequencies<'de, D>(deserializer: D) -> Result<Vec<(Regex, f32)>, D::Error>
147where
148 D: Deserializer<'de>,
149{
150 let strs: Vec<String> = Deserialize::deserialize(deserializer)?;
151 let mut result: Vec<(Regex, f32)> = Vec::with_capacity(strs.len());
152 for s in strs {
153 let i = s
154 .find('=')
155 .ok_or_else(|| de::Error::custom(format!("Invalid 'max_frequency': {s}")))?;
156 let regex = Regex::new(&s[0..i]).map_err(|e| {
157 de::Error::custom(format!("Invalid regex for 'max_frequency': '{s}': {e}"))
158 })?;
159 let frequency: f32 = s[i + 1..].parse().map_err(|e| {
160 de::Error::custom(format!(
161 "Invalid float value for 'max_frequency': '{s}': {e}"
162 ))
163 })?;
164 result.push((regex, frequency));
165 }
166 Ok(result)
167}
168
169fn default_queries_timeout() -> Duration {
170 Duration::from_secs_f32(DEFAULT_QUERIES_TIMEOUT)
171}
172
173fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
174where
175 D: Deserializer<'de>,
176{
177 let seconds: f32 = Deserialize::deserialize(deserializer)?;
178 Ok(Duration::from_secs_f32(seconds))
179}
180
181fn default_work_thread_num() -> usize {
182 DEFAULT_WORK_THREAD_NUM
183}
184
185fn default_max_block_thread_num() -> usize {
186 DEFAULT_MAX_BLOCK_THREAD_NUM
187}
188
189fn default_forward_discovery() -> bool {
190 DEFAULT_FORWARD_DISCOVERY
191}
192
193fn default_reliable_routes_blocking() -> bool {
194 DEFAULT_RELIABLE_ROUTES_BLOCKING
195}
196
197fn default_localhost_only() -> bool {
198 env::var("ROS_LOCALHOST_ONLY").as_deref() == Ok("1")
199}
200
201struct RegexVisitor;
204
205impl<'de> Visitor<'de> for RegexVisitor {
206 type Value = Option<Regex>;
207
208 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
209 formatter.write_str(r#"either a string or a list of strings"#)
210 }
211
212 fn visit_unit<E>(self) -> Result<Self::Value, E>
214 where
215 E: de::Error,
216 {
217 Ok(None)
218 }
219
220 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
221 where
222 E: de::Error,
223 {
224 Regex::new(value)
225 .map(Some)
226 .map_err(|e| de::Error::custom(format!("Invalid regex '{value}': {e}")))
227 }
228
229 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
230 where
231 A: de::SeqAccess<'de>,
232 {
233 let mut vec: Vec<String> = Vec::new();
234 while let Some(s) = seq.next_element()? {
235 vec.push(s);
236 }
237 let s: String = vec.join("|");
238 Regex::new(&s)
239 .map(Some)
240 .map_err(|e| de::Error::custom(format!("Invalid regex '{s}': {e}")))
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::Config;

    /// Deserializes `json` into a [`Config`], failing the test with the serde
    /// error if the document is rejected.
    fn parse(json: &str) -> Config {
        serde_json::from_str::<Config>(json)
            .unwrap_or_else(|e| panic!("failed to parse config {json}: {e}"))
    }

    #[test]
    fn test_path_field() {
        let Config {
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path"}"#);

        assert_eq!(__path__, Some(vec!["/example/path".to_string()]));
        assert_eq!(__required__, None);
    }

    #[test]
    fn test_required_field() {
        let Config {
            __required__,
            __path__,
            ..
        } = parse(r#"{"__required__": true}"#);

        assert_eq!(__path__, None);
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_path_field_and_required_field() {
        let Config {
            __required__,
            __path__,
            ..
        } = parse(r#"{"__path__": "/example/path", "__required__": true}"#);

        assert_eq!(__path__, Some(vec!["/example/path".to_string()]));
        assert_eq!(__required__, Some(true));
    }

    #[test]
    fn test_no_path_field_and_no_required_field() {
        let Config {
            __required__,
            __path__,
            ..
        } = parse("{}");

        assert_eq!(__path__, None);
        assert_eq!(__required__, None);
    }
}
312}