watermelon_proto/
queue_group.rs

1use alloc::string::String;
2use core::{
3    fmt::{self, Display},
4    ops::Deref,
5};
6use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
7
8use bytestring::ByteString;
9
10/// A string that can be used to represent an queue group
11///
12/// `QueueGroup` contains a string that is guaranteed [^1] to
13/// contain a valid header name that meets the following requirements:
14///
15/// * The value is not empty
16/// * The value has a length less than or equal to 64 [^2]
17/// * The value does not contain any whitespace characters or `:`
18///
19/// `QueueGroup` can be constructed from [`QueueGroup::from_static`]
20/// or any of the `TryFrom` implementations.
21///
22/// [^1]: Because [`QueueGroup::from_dangerous_value`] is safe to call,
23///       unsafe code must not assume any of the above invariants.
24/// [^2]: Messages coming from the NATS server are allowed to violate this rule.
25#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
26pub struct QueueGroup(ByteString);
27
28impl QueueGroup {
29    /// Construct `QueueGroup` from a static string
30    ///
31    /// # Panics
32    ///
33    /// Will panic if `value` isn't a valid `QueueGroup`
34    #[must_use]
35    pub fn from_static(value: &'static str) -> Self {
36        Self::try_from(ByteString::from_static(value)).expect("invalid QueueGroup")
37    }
38
39    /// Construct a `QueueGroup` from a string, without checking invariants
40    ///
41    /// This method bypasses invariants checks implemented by [`QueueGroup::from_static`]
42    /// and all `TryFrom` implementations.
43    ///
44    /// # Security
45    ///
46    /// While calling this method can eliminate the runtime performance cost of
47    /// checking the string, constructing `QueueGroup` with an invalid string and
48    /// then calling the NATS server with it can cause serious security issues.
49    /// When in doubt use the [`QueueGroup::from_static`] or any of the `TryFrom`
50    /// implementations.
51    #[must_use]
52    #[expect(
53        clippy::missing_panics_doc,
54        reason = "The queue group validation is only made in debug"
55    )]
56    pub fn from_dangerous_value(value: ByteString) -> Self {
57        if cfg!(debug_assertions) {
58            if let Err(err) = validate_queue_group(&value) {
59                panic!("QueueGroup {value:?} isn't valid {err:?}");
60            }
61        }
62        Self(value)
63    }
64
65    #[must_use]
66    pub fn as_str(&self) -> &str {
67        &self.0
68    }
69}
70
71impl Display for QueueGroup {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        Display::fmt(&self.0, f)
74    }
75}
76
77impl TryFrom<ByteString> for QueueGroup {
78    type Error = QueueGroupValidateError;
79
80    fn try_from(value: ByteString) -> Result<Self, Self::Error> {
81        validate_queue_group(&value)?;
82        Ok(Self::from_dangerous_value(value))
83    }
84}
85
86impl TryFrom<String> for QueueGroup {
87    type Error = QueueGroupValidateError;
88
89    fn try_from(value: String) -> Result<Self, Self::Error> {
90        validate_queue_group(&value)?;
91        Ok(Self::from_dangerous_value(value.into()))
92    }
93}
94
95impl From<QueueGroup> for ByteString {
96    fn from(value: QueueGroup) -> Self {
97        value.0
98    }
99}
100
101impl AsRef<[u8]> for QueueGroup {
102    fn as_ref(&self) -> &[u8] {
103        self.as_str().as_bytes()
104    }
105}
106
107impl AsRef<str> for QueueGroup {
108    fn as_ref(&self) -> &str {
109        self.as_str()
110    }
111}
112
113impl Deref for QueueGroup {
114    type Target = str;
115
116    fn deref(&self) -> &Self::Target {
117        self.as_str()
118    }
119}
120
121impl Serialize for QueueGroup {
122    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
123        self.as_str().serialize(serializer)
124    }
125}
126
127impl<'de> Deserialize<'de> for QueueGroup {
128    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
129        let s = ByteString::deserialize(deserializer)?;
130        s.try_into().map_err(de::Error::custom)
131    }
132}
133
134/// An error encountered while validating [`QueueGroup`]
135#[derive(Debug, thiserror::Error)]
136#[cfg_attr(test, derive(PartialEq, Eq))]
137pub enum QueueGroupValidateError {
138    /// The value is empty
139    #[error("QueueGroup is empty")]
140    Empty,
141    /// The value has a length greater than 64
142    #[error("QueueGroup is too long")]
143    TooLong,
144    /// The value contains an Unicode whitespace character
145    #[error("QueueGroup contained an illegal whitespace character")]
146    IllegalCharacter,
147}
148
149fn validate_queue_group(queue_group: &str) -> Result<(), QueueGroupValidateError> {
150    if queue_group.is_empty() {
151        return Err(QueueGroupValidateError::Empty);
152    }
153
154    if queue_group.len() > 64 {
155        // This is an arbitrary limit, but I guess the server must also have one
156        return Err(QueueGroupValidateError::TooLong);
157    }
158
159    if queue_group.chars().any(char::is_whitespace) {
160        // The theoretical security limit is just ` `, `\t`, `\r` and `\n`.
161        // Let's be more careful.
162        return Err(QueueGroupValidateError::IllegalCharacter);
163    }
164
165    Ok(())
166}
167
#[cfg(test)]
mod tests {
    use bytestring::ByteString;

    use super::{QueueGroup, QueueGroupValidateError};

    #[test]
    fn valid_queue_groups() {
        let queue_groups = ["importer", "importer.thing", "blablabla:itworks"];
        for queue_group in queue_groups {
            let q = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap();
            assert_eq!(queue_group, q.as_str());
        }
    }

    #[test]
    fn valid_queue_groups_at_max_length() {
        // The limit is inclusive and measured in bytes: `\u{e9}` is 2 bytes in UTF-8.
        for queue_group in ["a".repeat(64), "\u{e9}".repeat(32)] {
            let q = QueueGroup::try_from(queue_group.clone()).unwrap();
            assert_eq!(queue_group, q.as_str());
        }
    }

    #[test]
    fn queue_groups_over_max_length() {
        // 65 bytes, and 66 bytes made of only 33 characters.
        for queue_group in ["a".repeat(65), "\u{e9}".repeat(33)] {
            let err = QueueGroup::try_from(queue_group).unwrap_err();
            assert_eq!(QueueGroupValidateError::TooLong, err);
        }
    }

    #[test]
    fn invalid_queue_groups() {
        let queue_groups = [
            ("", QueueGroupValidateError::Empty),
            (
                "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                QueueGroupValidateError::TooLong,
            ),
            ("importer ", QueueGroupValidateError::IllegalCharacter),
            ("importer .thing", QueueGroupValidateError::IllegalCharacter),
            (" importer", QueueGroupValidateError::IllegalCharacter),
            ("importer.thing ", QueueGroupValidateError::IllegalCharacter),
            (
                "importer.thing.works ",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing.works\r",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing.works\n",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing.works\t",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thi ng.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thi\rng.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thi\nng.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thi\tng.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing .works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing\r.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing\n.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer.thing\t.works",
                QueueGroupValidateError::IllegalCharacter,
            ),
            // Non-ASCII Unicode whitespace is rejected too.
            (
                "importer\u{a0}thing",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (
                "importer\u{3000}thing",
                QueueGroupValidateError::IllegalCharacter,
            ),
            (" ", QueueGroupValidateError::IllegalCharacter),
            ("\r", QueueGroupValidateError::IllegalCharacter),
            ("\n", QueueGroupValidateError::IllegalCharacter),
            ("\t", QueueGroupValidateError::IllegalCharacter),
        ];
        for (queue_group, expected_err) in queue_groups {
            let err = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap_err();
            assert_eq!(expected_err, err);
        }
    }

    #[test]
    fn from_static_accepts_valid_queue_group() {
        let q = QueueGroup::from_static("importer.thing");
        assert_eq!("importer.thing", q.as_str());
    }

    #[test]
    #[should_panic(expected = "invalid QueueGroup")]
    fn from_static_rejects_invalid_queue_group() {
        let _ = QueueGroup::from_static("importer thing");
    }
}