zenoh_protocol/core/
whatami.rs

//
// Copyright (c) 2023 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
14use alloc::string::String;
15use core::{convert::TryFrom, fmt, num::NonZeroU8, ops::BitOr, str::FromStr};
16
17use const_format::formatcp;
18use serde::ser::SerializeSeq;
19use zenoh_result::{bail, ZError};
20
/// Identifies what kind of node this is: a router, a peer or a client.
///
/// Each variant is a distinct single bit, so several roles can be combined
/// into a [`WhatAmIMatcher`]. The default role is [`WhatAmI::Peer`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(u8)]
pub enum WhatAmI {
    /// Encoded as `0b001`.
    Router = 0b001,
    /// Encoded as `0b010`; this is the `Default` variant.
    #[default]
    Peer = 0b010,
    /// Encoded as `0b100`.
    Client = 0b100,
}
29
30impl WhatAmI {
31    const STR_R: &'static str = "router";
32    const STR_P: &'static str = "peer";
33    const STR_C: &'static str = "client";
34
35    const U8_R: u8 = Self::Router as u8;
36    const U8_P: u8 = Self::Peer as u8;
37    const U8_C: u8 = Self::Client as u8;
38
39    pub const fn to_str(self) -> &'static str {
40        match self {
41            Self::Router => Self::STR_R,
42            Self::Peer => Self::STR_P,
43            Self::Client => Self::STR_C,
44        }
45    }
46
47    #[cfg(feature = "test")]
48    pub fn rand() -> Self {
49        use rand::prelude::SliceRandom;
50        let mut rng = rand::thread_rng();
51
52        *[Self::Router, Self::Peer, Self::Client]
53            .choose(&mut rng)
54            .unwrap()
55    }
56}
57
58impl TryFrom<u8> for WhatAmI {
59    type Error = ();
60
61    fn try_from(v: u8) -> Result<Self, Self::Error> {
62        match v {
63            Self::U8_R => Ok(Self::Router),
64            Self::U8_P => Ok(Self::Peer),
65            Self::U8_C => Ok(Self::Client),
66            _ => Err(()),
67        }
68    }
69}
70
71impl FromStr for WhatAmI {
72    type Err = ZError;
73
74    fn from_str(s: &str) -> Result<Self, Self::Err> {
75        match s {
76            Self::STR_R => Ok(Self::Router),
77            Self::STR_P => Ok(Self::Peer),
78            Self::STR_C => Ok(Self::Client),
79            _ => bail!(
80                "{s} is not a valid WhatAmI value. Valid values are: {}, {}, {}.",
81                Self::STR_R,
82                Self::STR_P,
83                Self::STR_C
84            ),
85        }
86    }
87}
88
89impl fmt::Display for WhatAmI {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.write_str(self.to_str())
92    }
93}
94
95impl From<WhatAmI> for u8 {
96    fn from(w: WhatAmI) -> Self {
97        w as u8
98    }
99}
100
/// A set of [`WhatAmI`] roles stored as a bitmask in a single byte.
///
/// The low three bits hold the role bits. Bit 7 (`1 << 7`) is always set as a
/// marker, so the inner byte is never zero; thanks to the `NonZeroU8` niche,
/// `Option<WhatAmIMatcher>` therefore also fits in one byte.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct WhatAmIMatcher(NonZeroU8);
104
105impl WhatAmIMatcher {
106    // We use the 7th bit for detecting whether the WhatAmIMatcher is non-zero
107    const U8_0: u8 = 1 << 7;
108    const U8_R: u8 = Self::U8_0 | WhatAmI::U8_R;
109    const U8_P: u8 = Self::U8_0 | WhatAmI::U8_P;
110    const U8_C: u8 = Self::U8_0 | WhatAmI::U8_C;
111    const U8_R_P: u8 = Self::U8_0 | WhatAmI::U8_R | WhatAmI::U8_P;
112    const U8_P_C: u8 = Self::U8_0 | WhatAmI::U8_P | WhatAmI::U8_C;
113    const U8_R_C: u8 = Self::U8_0 | WhatAmI::U8_R | WhatAmI::U8_C;
114    const U8_R_P_C: u8 = Self::U8_0 | WhatAmI::U8_R | WhatAmI::U8_P | WhatAmI::U8_C;
115
116    pub const fn empty() -> Self {
117        Self(unsafe { NonZeroU8::new_unchecked(Self::U8_0) })
118    }
119
120    pub const fn router(self) -> Self {
121        Self(unsafe { NonZeroU8::new_unchecked(self.0.get() | Self::U8_R) })
122    }
123
124    pub const fn peer(self) -> Self {
125        Self(unsafe { NonZeroU8::new_unchecked(self.0.get() | Self::U8_P) })
126    }
127
128    pub const fn client(self) -> Self {
129        Self(unsafe { NonZeroU8::new_unchecked(self.0.get() | Self::U8_C) })
130    }
131
132    pub const fn is_empty(&self) -> bool {
133        self.0.get() == Self::U8_0
134    }
135
136    pub const fn matches(&self, w: WhatAmI) -> bool {
137        (self.0.get() & w as u8) != 0
138    }
139
140    pub const fn to_str(self) -> &'static str {
141        match self.0.get() {
142            Self::U8_0 => "",
143            Self::U8_R => WhatAmI::STR_R,
144            Self::U8_P => WhatAmI::STR_P,
145            Self::U8_C => WhatAmI::STR_C,
146            Self::U8_R_P => formatcp!("{}|{}", WhatAmI::STR_R, WhatAmI::STR_P),
147            Self::U8_R_C => formatcp!("{}|{}", WhatAmI::STR_R, WhatAmI::STR_C),
148            Self::U8_P_C => formatcp!("{}|{}", WhatAmI::STR_P, WhatAmI::STR_C),
149            Self::U8_R_P_C => formatcp!("{}|{}|{}", WhatAmI::STR_R, WhatAmI::STR_P, WhatAmI::STR_C),
150
151            _ => unreachable!(),
152        }
153    }
154
155    #[cfg(feature = "test")]
156    pub fn rand() -> Self {
157        use rand::Rng;
158
159        let mut rng = rand::thread_rng();
160        let mut waim = WhatAmIMatcher::empty();
161        if rng.gen_bool(0.5) {
162            waim = waim.router();
163        }
164        if rng.gen_bool(0.5) {
165            waim = waim.peer();
166        }
167        if rng.gen_bool(0.5) {
168            waim = waim.client();
169        }
170        waim
171    }
172}
173
174impl TryFrom<u8> for WhatAmIMatcher {
175    type Error = ();
176
177    fn try_from(v: u8) -> Result<Self, Self::Error> {
178        const MIN: u8 = 0;
179        const MAX: u8 = WhatAmI::U8_R | WhatAmI::U8_P | WhatAmI::U8_C;
180
181        if (MIN..=MAX).contains(&v) {
182            Ok(WhatAmIMatcher(unsafe {
183                NonZeroU8::new_unchecked(Self::U8_0 | v)
184            }))
185        } else {
186            Err(())
187        }
188    }
189}
190
191impl FromStr for WhatAmIMatcher {
192    type Err = ();
193
194    fn from_str(s: &str) -> Result<Self, Self::Err> {
195        let mut inner = 0;
196        for s in s.split('|') {
197            match s.trim() {
198                "" => {}
199                WhatAmI::STR_R => inner |= WhatAmI::U8_R,
200                WhatAmI::STR_P => inner |= WhatAmI::U8_P,
201                WhatAmI::STR_C => inner |= WhatAmI::U8_C,
202                _ => return Err(()),
203            }
204        }
205        Self::try_from(inner)
206    }
207}
208
209impl fmt::Display for WhatAmIMatcher {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        f.write_str(self.to_str())
212    }
213}
214
215impl From<WhatAmIMatcher> for u8 {
216    fn from(w: WhatAmIMatcher) -> u8 {
217        w.0.get()
218    }
219}
220
221impl<T> BitOr<T> for WhatAmIMatcher
222where
223    NonZeroU8: BitOr<T, Output = NonZeroU8>,
224{
225    type Output = Self;
226
227    fn bitor(self, rhs: T) -> Self::Output {
228        WhatAmIMatcher(self.0 | rhs)
229    }
230}
231
232impl BitOr<WhatAmI> for WhatAmIMatcher {
233    type Output = Self;
234
235    fn bitor(self, rhs: WhatAmI) -> Self::Output {
236        self | rhs as u8
237    }
238}
239
240impl BitOr for WhatAmIMatcher {
241    type Output = Self;
242
243    fn bitor(self, rhs: Self) -> Self::Output {
244        self | rhs.0
245    }
246}
247
248impl BitOr for WhatAmI {
249    type Output = WhatAmIMatcher;
250
251    fn bitor(self, rhs: Self) -> Self::Output {
252        WhatAmIMatcher(unsafe {
253            NonZeroU8::new_unchecked(self as u8 | rhs as u8 | WhatAmIMatcher::U8_0)
254        })
255    }
256}
257
258impl From<WhatAmI> for WhatAmIMatcher {
259    fn from(w: WhatAmI) -> Self {
260        WhatAmIMatcher(unsafe { NonZeroU8::new_unchecked(w as u8 | WhatAmIMatcher::U8_0) })
261    }
262}
263
// Serde (de)serialization support
265impl serde::Serialize for WhatAmI {
266    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
267    where
268        S: serde::Serializer,
269    {
270        serializer.serialize_str(self.to_str())
271    }
272}
273
274pub struct WhatAmIVisitor;
275
276impl<'de> serde::de::Visitor<'de> for WhatAmIVisitor {
277    type Value = WhatAmI;
278
279    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
280        write!(
281            formatter,
282            "either '{}', '{}' or '{}'",
283            WhatAmI::STR_R,
284            WhatAmI::STR_P,
285            WhatAmI::STR_C
286        )
287    }
288    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
289    where
290        E: serde::de::Error,
291    {
292        v.parse().map_err(|_| {
293            serde::de::Error::unknown_variant(v, &[WhatAmI::STR_R, WhatAmI::STR_P, WhatAmI::STR_C])
294        })
295    }
296    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
297    where
298        E: serde::de::Error,
299    {
300        self.visit_str(v)
301    }
302    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
303    where
304        E: serde::de::Error,
305    {
306        self.visit_str(&v)
307    }
308}
309
310impl<'de> serde::Deserialize<'de> for WhatAmI {
311    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
312    where
313        D: serde::Deserializer<'de>,
314    {
315        deserializer.deserialize_str(WhatAmIVisitor)
316    }
317}
318
319impl serde::Serialize for WhatAmIMatcher {
320    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
321    where
322        S: serde::Serializer,
323    {
324        let values = [WhatAmI::Router, WhatAmI::Peer, WhatAmI::Client]
325            .iter()
326            .filter(|v| self.matches(**v));
327        let mut seq = serializer.serialize_seq(Some(values.clone().count()))?;
328        for v in values {
329            seq.serialize_element(v)?;
330        }
331        seq.end()
332    }
333}
334
335pub struct WhatAmIMatcherVisitor;
336impl<'de> serde::de::Visitor<'de> for WhatAmIMatcherVisitor {
337    type Value = WhatAmIMatcher;
338    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
339        write!(
340            formatter,
341            "a list of whatami variants ('{}', '{}', '{}')",
342            WhatAmI::STR_R,
343            WhatAmI::STR_P,
344            WhatAmI::STR_C
345        )
346    }
347
348    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
349    where
350        A: serde::de::SeqAccess<'de>,
351    {
352        let mut inner = 0;
353
354        while let Some(s) = seq.next_element::<String>()? {
355            match s.as_str() {
356                WhatAmI::STR_R => inner |= WhatAmI::U8_R,
357                WhatAmI::STR_P => inner |= WhatAmI::U8_P,
358                WhatAmI::STR_C => inner |= WhatAmI::U8_C,
359                _ => {
360                    return Err(serde::de::Error::invalid_value(
361                        serde::de::Unexpected::Str(&s),
362                        &formatcp!(
363                            "one of ('{}', '{}', '{}')",
364                            WhatAmI::STR_R,
365                            WhatAmI::STR_P,
366                            WhatAmI::STR_C
367                        ),
368                    ))
369                }
370            }
371        }
372
373        Ok(WhatAmIMatcher::try_from(inner)
374            .expect("`WhatAmIMatcher` should be valid by construction"))
375    }
376}
377
378impl<'de> serde::Deserialize<'de> for WhatAmIMatcher {
379    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
380    where
381        D: serde::Deserializer<'de>,
382    {
383        deserializer.deserialize_seq(WhatAmIMatcherVisitor)
384    }
385}