snops_common/
node_targets.rs

1use core::fmt;
2use std::str::FromStr;
3
4use http::StatusCode;
5use lazy_static::lazy_static;
6use regex::Regex;
7use serde::{
8    de::{Error, Visitor},
9    ser::SerializeSeq,
10    Deserialize, Serialize,
11};
12use thiserror::Error;
13use wildmatch::WildMatch;
14
15use crate::{
16    format::*,
17    impl_into_status_code,
18    state::{NodeKey, NodeType},
19};
20
21#[derive(Debug, Error)]
22#[error("invalid node target string")]
23pub struct NodeTargetError;
24
25impl_into_status_code!(NodeTargetError, |_| StatusCode::BAD_REQUEST);
26
27/// One or more deserialized node targets. Composed of one or more
28/// [`NodeTarget`]s.
29#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)]
30pub enum NodeTargets {
31    #[default]
32    None,
33    One(NodeTarget),
34    Many(Vec<NodeTarget>),
35}
36
37impl DataFormat for NodeTargets {
38    type Header = DataHeaderOf<NodeTarget>;
39    const LATEST_HEADER: Self::Header = NodeTarget::LATEST_HEADER;
40
41    fn write_data<W: std::io::prelude::Write>(
42        &self,
43        writer: &mut W,
44    ) -> Result<usize, DataWriteError> {
45        match self {
46            NodeTargets::None => vec![],
47            NodeTargets::One(target) => vec![target.clone()],
48            NodeTargets::Many(targets) => targets.clone(),
49        }
50        .write_data(writer)
51    }
52
53    fn read_data<R: std::io::prelude::Read>(
54        reader: &mut R,
55        header: &Self::Header,
56    ) -> Result<Self, DataReadError> {
57        let targets = Vec::<NodeTarget>::read_data(reader, header)?;
58        Ok(NodeTargets::from(targets))
59    }
60}
61
62impl<'de> Deserialize<'de> for NodeTargets {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: serde::Deserializer<'de>,
66    {
67        struct NodeTargetsVisitor;
68
69        impl<'de> Visitor<'de> for NodeTargetsVisitor {
70            type Value = NodeTargets;
71
72            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
73                formatter.write_str("one or more node targets")
74            }
75
76            fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
77                if v.contains(',') {
78                    return Ok(NodeTargets::Many(
79                        v.split(',')
80                            .map(|s| NodeTarget::from_str(s.trim()).map_err(E::custom))
81                            .collect::<Result<_, _>>()?,
82                    ));
83                }
84                Ok(NodeTargets::One(FromStr::from_str(v).map_err(E::custom)?))
85            }
86
87            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
88            where
89                A: serde::de::SeqAccess<'de>,
90            {
91                let mut buf = vec![];
92
93                while let Some(elem) = seq.next_element()? {
94                    buf.push(NodeTarget::from_str(elem).map_err(A::Error::custom)?);
95                }
96
97                Ok(if buf.is_empty() {
98                    NodeTargets::None
99                } else {
100                    NodeTargets::Many(buf)
101                })
102            }
103        }
104
105        deserializer.deserialize_any(NodeTargetsVisitor)
106    }
107}
108
109lazy_static! {
110    static ref NODE_TARGET_REGEX: Regex =
111        Regex::new(r"^(?P<ty>\*|client|validator|prover)\/(?P<id>[A-Za-z0-9\-*]+)(?:@(?P<ns>[A-Za-z0-9\-*]+))?$")
112            .unwrap();
113}
114
115impl Serialize for NodeTargets {
116    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117    where
118        S: serde::Serializer,
119    {
120        match self {
121            NodeTargets::None => serializer.serialize_seq(Some(0))?.end(),
122            NodeTargets::One(target) => serializer.serialize_str(&target.to_string()),
123            NodeTargets::Many(targets) => {
124                let mut seq = serializer.serialize_seq(Some(targets.len()))?;
125                for target in targets {
126                    seq.serialize_element(&target.to_string())?;
127                }
128                seq.end()
129            }
130        }
131    }
132}
133
134impl fmt::Display for NodeTargets {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        match self {
137            NodeTargets::None => write!(f, ""),
138            NodeTargets::One(target) => write!(f, "{}", target),
139            NodeTargets::Many(targets) => {
140                let mut iter = targets.iter();
141                if let Some(target) = iter.next() {
142                    write!(f, "{}", target)?;
143                    for target in iter {
144                        write!(f, ", {}", target)?;
145                    }
146                }
147                Ok(())
148            }
149        }
150    }
151}
152
153impl NodeTargets {
154    pub const ALL: Self = Self::One(NodeTarget::ALL);
155
156    pub fn is_all(&self) -> bool {
157        if matches!(self, NodeTargets::One(NodeTarget::ALL)) {
158            return true;
159        }
160
161        if let NodeTargets::Many(targets) = self {
162            return targets.iter().any(|target| target == &NodeTarget::ALL);
163        }
164
165        false
166    }
167}
168
169/// A **single** matched node target. Use [`NodeTargets`] when deserializing
170/// from documents.
171#[derive(Debug, Clone, PartialEq, Hash, Eq)]
172pub struct NodeTarget {
173    pub ty: NodeTargetType,
174    pub id: NodeTargetId,
175    pub ns: NodeTargetNamespace,
176}
177
178impl FromStr for NodeTarget {
179    type Err = NodeTargetError;
180
181    fn from_str(s: &str) -> Result<Self, Self::Err> {
182        let captures = NODE_TARGET_REGEX.captures(s).ok_or(NodeTargetError)?;
183
184        // match the type
185        let ty = match &captures["ty"] {
186            "*" => NodeTargetType::All,
187            "client" => NodeTargetType::One(NodeType::Client),
188            "validator" => NodeTargetType::One(NodeType::Validator),
189            "prover" => NodeTargetType::One(NodeType::Prover),
190            _ => unreachable!(),
191        };
192
193        // match the node ID
194        let id = match &captures["id"] {
195            // full wildcard
196            "*" => NodeTargetId::All,
197
198            // partial wildcard
199            id if id.contains('*') => NodeTargetId::WildcardPattern(WildMatch::new(id)),
200
201            // literal string
202            id => NodeTargetId::Literal(id.into()),
203        };
204
205        // match the namespace
206        let ns = match captures.name("ns") {
207            // full wildcard
208            Some(id) if id.as_str() == "*" => NodeTargetNamespace::All,
209
210            // local; either explicitly stated, or empty
211            Some(id) if id.as_str() == "local" => NodeTargetNamespace::Local,
212            None => NodeTargetNamespace::Local,
213
214            // literal namespace
215            Some(id) => NodeTargetNamespace::Literal(id.as_str().into()),
216        };
217
218        Ok(Self { ty, id, ns })
219    }
220}
221
222impl fmt::Display for NodeTarget {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        write!(
225            f,
226            "{}/{}{}",
227            match self.ty {
228                NodeTargetType::All => "*".to_owned(),
229                NodeTargetType::One(ty) => ty.to_string(),
230            },
231            match &self.id {
232                NodeTargetId::All => "*".to_owned(),
233                NodeTargetId::WildcardPattern(pattern) => pattern.to_string(),
234                NodeTargetId::Literal(id) => id.to_owned(),
235            },
236            match &self.ns {
237                NodeTargetNamespace::All => "@*".to_owned(),
238                NodeTargetNamespace::Local => "".to_owned(),
239                NodeTargetNamespace::Literal(ns) => format!("@{}", ns),
240            }
241        )
242    }
243}
244
245impl Serialize for NodeTarget {
246    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
247    where
248        S: serde::Serializer,
249    {
250        serializer.serialize_str(&self.to_string())
251    }
252}
253
254impl<'de> Deserialize<'de> for NodeTarget {
255    fn deserialize<D>(deserializer: D) -> Result<NodeTarget, D::Error>
256    where
257        D: serde::Deserializer<'de>,
258    {
259        let s = String::deserialize(deserializer)?;
260        NodeTarget::from_str(&s).map_err(D::Error::custom)
261    }
262}
263
264impl DataFormat for NodeTarget {
265    type Header = (u8, DataHeaderOf<NodeType>);
266    const LATEST_HEADER: Self::Header = (1, NodeType::LATEST_HEADER);
267
268    fn write_data<W: std::io::prelude::Write>(
269        &self,
270        writer: &mut W,
271    ) -> Result<usize, DataWriteError> {
272        let mut written = 0;
273        written += match self.ty {
274            NodeTargetType::All => 0u8.write_data(writer)?,
275            NodeTargetType::One(ty) => 1u8.write_data(writer)? + ty.write_data(writer)?,
276        };
277        written += match &self.id {
278            NodeTargetId::All => 0u8.write_data(writer)?,
279            NodeTargetId::WildcardPattern(pattern) => {
280                1u8.write_data(writer)? + pattern.to_string().write_data(writer)?
281            }
282            NodeTargetId::Literal(id) => 2u8.write_data(writer)? + id.write_data(writer)?,
283        };
284        written += match &self.ns {
285            NodeTargetNamespace::All => 0u8.write_data(writer)?,
286            NodeTargetNamespace::Local => 1u8.write_data(writer)?,
287            NodeTargetNamespace::Literal(ns) => 2u8.write_data(writer)? + ns.write_data(writer)?,
288        };
289
290        Ok(written)
291    }
292
293    fn read_data<R: std::io::prelude::Read>(
294        reader: &mut R,
295        header: &Self::Header,
296    ) -> Result<Self, DataReadError> {
297        if header.0 != Self::LATEST_HEADER.0 {
298            return Err(DataReadError::unsupported(
299                "NodeTarget",
300                Self::LATEST_HEADER.0,
301                header.0,
302            ));
303        }
304
305        let ty = match reader.read_data(&())? {
306            0u8 => NodeTargetType::All,
307            1u8 => NodeTargetType::One(NodeType::read_data(reader, &header.1)?),
308            n => {
309                return Err(DataReadError::Custom(format!(
310                    "invalid NodeTarget type discriminant: {n}"
311                )))
312            }
313        };
314
315        let id = match reader.read_data(&())? {
316            0u8 => NodeTargetId::All,
317            1u8 => {
318                let pattern = String::read_data(reader, &())?;
319                NodeTargetId::WildcardPattern(WildMatch::new(&pattern))
320            }
321            2u8 => NodeTargetId::Literal(reader.read_data(&())?),
322            n => {
323                return Err(DataReadError::Custom(format!(
324                    "invalid NodeTarget ID discriminant: {n}"
325                )))
326            }
327        };
328
329        let ns = match reader.read_data(&())? {
330            0u8 => NodeTargetNamespace::All,
331            1u8 => NodeTargetNamespace::Local,
332            2u8 => NodeTargetNamespace::Literal(reader.read_data(&())?),
333            n => {
334                return Err(DataReadError::Custom(format!(
335                    "invalid NodeTarget namespace discriminant: {n}"
336                )))
337            }
338        };
339
340        Ok(Self { ty, id, ns })
341    }
342}
343
344#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
345pub enum NodeTargetType {
346    /// Matches all node types.
347    All,
348    /// Matches a particular node type.
349    One(NodeType),
350}
351
352#[derive(Debug, Clone, PartialEq)]
353pub enum NodeTargetId {
354    /// `*`. Matches all IDs.
355    All,
356    /// A wildcard pattern, like `foo-*`.
357    WildcardPattern(WildMatch),
358    /// A literal name, like `foo-node` or `1`.
359    Literal(String),
360}
361
362impl Eq for NodeTargetId {}
363
364impl std::hash::Hash for NodeTargetId {
365    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
366        match self {
367            NodeTargetId::All => "*".hash(state),
368            NodeTargetId::WildcardPattern(pattern) => pattern.to_string().hash(state),
369            NodeTargetId::Literal(id) => id.hash(state),
370        }
371    }
372}
373
374#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
375pub enum NodeTargetNamespace {
376    /// `*`. Matches all namespaces.
377    All,
378    /// A literal name, like `mainnet`.
379    Literal(String),
380    /// The local namespace.
381    Local,
382}
383
384impl From<NodeKey> for NodeTarget {
385    fn from(value: NodeKey) -> Self {
386        Self {
387            ty: NodeTargetType::One(value.ty),
388            id: NodeTargetId::Literal(value.id),
389            ns: value
390                .ns
391                .map(NodeTargetNamespace::Literal)
392                .unwrap_or(NodeTargetNamespace::Local),
393        }
394    }
395}
396
397impl From<Vec<NodeTarget>> for NodeTargets {
398    fn from(nodes: Vec<NodeTarget>) -> Self {
399        match nodes.len() {
400            0 => Self::None,
401            1 => Self::One(nodes.into_iter().next().unwrap()),
402            _ => Self::Many(nodes),
403        }
404    }
405}
406
407impl NodeTarget {
408    pub const ALL: Self = Self {
409        ty: NodeTargetType::All,
410        id: NodeTargetId::All,
411        ns: NodeTargetNamespace::All,
412    };
413
414    pub fn matches(&self, key: &NodeKey) -> bool {
415        (match self.ty {
416            NodeTargetType::All => true,
417            NodeTargetType::One(ty) => ty == key.ty,
418        }) && (match &self.id {
419            NodeTargetId::All => true,
420            NodeTargetId::WildcardPattern(pattern) => pattern.matches(&key.id),
421            NodeTargetId::Literal(id) => &key.id == id,
422        }) && (match &self.ns {
423            NodeTargetNamespace::All => true,
424            NodeTargetNamespace::Local => key.ns.is_none() || key.ns == Some("local".into()),
425            NodeTargetNamespace::Literal(ns) => {
426                ns == "local" && key.ns.is_none()
427                    || key.ns.as_ref().map_or(false, |key_ns| key_ns == ns)
428            }
429        })
430    }
431}
432
433impl NodeTargets {
434    pub fn is_empty(&self) -> bool {
435        if matches!(self, &NodeTargets::None) {
436            return true;
437        }
438
439        if let NodeTargets::Many(targets) = self {
440            return targets.is_empty();
441        }
442
443        false
444    }
445
446    pub fn matches(&self, key: &NodeKey) -> bool {
447        match self {
448            NodeTargets::None => false,
449            NodeTargets::One(target) => target.matches(key),
450            NodeTargets::Many(targets) => targets.iter().any(|target| target.matches(key)),
451        }
452    }
453}