stomp_parser/model/headers/
mod.rs

//! Implements the model for headers, as specified in the
//! [STOMP Protocol Specification, Version 1.2](https://stomp.github.io/stomp-specification-1.2.html).
3#[macro_use]
4mod macros;
5use crate::common::functions::decode_str;
6use crate::error::StompParseError;
7use either::Either;
8use paste::paste;
9use std::convert::TryFrom;
10use std::str::FromStr;
11
/// A header value that knows its header name and its value types, and can be displayed.
///
/// The header types generated in this module display in `name:value` form; for example a
/// `ContentLengthValue` holding 10 displays as `content-length:10`.
pub trait HeaderValue: std::fmt::Display {
    /// NOTE(review): presumably the owned form of the header's value — confirm against the
    /// `headers!` macro.
    type OwnedValue;
    /// NOTE(review): presumably the form in which the value is exposed (possibly borrowed,
    /// e.g. `&'static str`) — confirm against the `headers!` macro.
    type Value;
    /// NOTE(review): presumably `true` when `Value` is itself an owned type — confirm
    /// against the `headers!` macro.
    const OWNED: bool;

    /// Returns the name of this header (presumably the wire name, such as `content-length`).
    fn header_name(&self) -> &str;
}
20
21pub trait DecodableValue {
22    fn decoded_value(&self) -> Result<Either<&str, String>, StompParseError>;
23}
/// A raw header, held as an undecoded name/value pair of strings.
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct NameValue {
    /// The header name.
    pub name: String,
    /// The header value, stored exactly as given (no escape decoding happens here).
    pub value: String,
}

impl std::fmt::Display for NameValue {
    /// Writes the header as `name:value`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        f.write_str(&self.name)?;
        f.write_str(":")?;
        f.write_str(&self.value)
    }
}
35
/// Splits `input` at the first occurrence of `delim`.
///
/// Returns the text before and after the delimiter (the delimiter itself is excluded), or
/// `None` if `delim` does not occur in `input`. Later occurrences of `delim` stay in the
/// second half.
fn split_once(input: &str, delim: char) -> Option<(&str, &str)> {
    let idx = input.find(delim)?;
    // Skip the delimiter's full UTF-8 width rather than a single byte, so a non-ASCII
    // delimiter never produces a slice that starts inside a code point (which would panic).
    Some((&input[..idx], &input[idx + delim.len_utf8()..]))
}
41
42impl FromStr for NameValue {
43    type Err = StompParseError;
44    fn from_str(input: &str) -> Result<NameValue, StompParseError> {
45        split_once(input, ':')
46            .map(|(name, value)| NameValue {
47                name: name.to_owned(),
48                value: value.to_owned(),
49            })
50            .ok_or_else(|| StompParseError::new(format!("Poorly formatted header: {}", input)))
51    }
52}
53
/// A pair of heartbeat intervals carried by the containing message: how often its
/// originator will send heartbeats, and how often it expects to receive them.
///
/// The numbers are stored as given; nothing in this type interprets their units.
#[derive(Eq, PartialEq, Debug, Clone, Default)]
pub struct HeartBeatIntervals {
    /// Interval at which the originator will supply heartbeats.
    pub supplied: u32,
    /// Interval at which the originator expects to receive heartbeats.
    pub expected: u32,
}

impl HeartBeatIntervals {
    /// Builds a pair from the `supplied` and `expected` intervals, in that order.
    pub fn new(supplied: u32, expected: u32) -> HeartBeatIntervals {
        Self { supplied, expected }
    }
}

impl std::fmt::Display for HeartBeatIntervals {
    /// Writes the pair in header form: `<supplied>,<expected>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        write!(f, "{},{}", self.supplied, self.expected)
    }
}
73
74impl FromStr for HeartBeatIntervals {
75    type Err = StompParseError;
76    /// Parses the string message as two ints representing "supplied, expected" heartbeat intervalls
77    fn from_str(input: &str) -> Result<HeartBeatIntervals, StompParseError> {
78        split_once(input, ',')
79            .ok_or_else(|| StompParseError::new(format!("Poorly formatted heartbeats: {}", input)))
80            .and_then(|(supplied, expected)| {
81                u32::from_str(expected)
82                    .and_then(|expected| {
83                        u32::from_str(supplied)
84                            .map(|supplied| HeartBeatIntervals { expected, supplied })
85                    })
86                    .map_err(|_| {
87                        StompParseError::new(format!("Poorly formatted heartbeats: {}", input))
88                    })
89            })
90    }
91}
92
93#[derive(Eq, PartialEq, Debug, Clone)]
94pub struct StompVersions(pub Vec<StompVersion>);
95
96impl std::fmt::Display for StompVersions {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
98        write!(
99            f,
100            "{}",
101            self.0
102                .iter()
103                .map(|version| version.to_string())
104                .collect::<Vec<String>>()
105                .join(",")
106        )
107    }
108}
109
110impl FromStr for StompVersions {
111    type Err = StompParseError;
112    fn from_str(input: &str) -> Result<StompVersions, StompParseError> {
113        input
114            .split(',')
115            .map(|section| StompVersion::from_str(section))
116            .try_fold(Vec::new(), |mut vec, result| {
117                result
118                    .map(|version| {
119                        vec.push(version);
120                        vec
121                    })
122                    .map_err(|_| {
123                        StompParseError::new(format!("Poorly formatted accept-versions: {}", input))
124                    })
125            })
126            .map(StompVersions)
127    }
128}
129
130impl std::ops::Deref for StompVersions {
131    type Target = Vec<StompVersion>;
132
133    fn deref(&self) -> &Self::Target {
134        &self.0
135    }
136}
137
/// The acknowledgement mode requested for a subscription (the value of the `ack` header).
#[derive(Eq, PartialEq, Debug, Clone)]
pub enum AckType {
    /// The client need not send acks; messages count as received as soon as they are sent.
    Auto,
    /// The client must send ACK frames; each ack is cumulative, also acknowledging every
    /// earlier message.
    Client,
    /// The client must send ACK frames; each ack is individual, acknowledging only the
    /// specified message.
    ClientIndividual,
}

impl Default for AckType {
    /// Returns [`AckType::Auto`].
    fn default() -> Self {
        Self::Auto
    }
}

impl std::fmt::Display for AckType {
    /// Writes the header token: `auto`, `client` or `client-individual`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let token = match self {
            Self::Auto => "auto",
            Self::Client => "client",
            Self::ClientIndividual => "client-individual",
        };
        f.write_str(token)
    }
}
164
165impl FromStr for AckType {
166    type Err = StompParseError;
167    fn from_str(input: &str) -> Result<AckType, StompParseError> {
168        match input {
169            "auto" => Ok(AckType::Auto),
170            "client" => Ok(AckType::Client),
171            "client-individual" => Ok(AckType::ClientIndividual),
172            _ => Err(StompParseError::new(format!("Unknown ack-type: {}", input))),
173        }
174    }
175}
176
/// STOMP protocol versions that client and server can negotiate.
///
/// Version strings this crate does not recognise are kept verbatim in
/// [`StompVersion::Unknown`] rather than rejected.
// Allowance kept for the `V1_x` variant names, which separate digits with underscores.
#[allow(non_camel_case_types)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum StompVersion {
    /// Version `1.0`.
    V1_0,
    /// Version `1.1`.
    V1_1,
    /// Version `1.2`.
    V1_2,
    /// Any other version string, kept exactly as received.
    Unknown(String),
}

impl std::fmt::Display for StompVersion {
    /// Writes the version in its textual form, e.g. `1.2`; an unknown version is written
    /// back exactly as stored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        let text = match self {
            StompVersion::V1_0 => "1.0",
            StompVersion::V1_1 => "1.1",
            StompVersion::V1_2 => "1.2",
            // `Display` must only fail when the underlying writer does. Returning
            // `fmt::Error` for `Unknown` (which `FromStr` routinely produces) made
            // `to_string()` panic, so echo the raw text instead.
            StompVersion::Unknown(raw) => raw.as_str(),
        };
        f.write_str(text)
    }
}
198
199impl FromStr for StompVersion {
200    type Err = StompParseError;
201    fn from_str(input: &str) -> Result<StompVersion, StompParseError> {
202        match input {
203            "1.0" => Ok(StompVersion::V1_0),
204            "1.1" => Ok(StompVersion::V1_1),
205            "1.2" => Ok(StompVersion::V1_2),
206            _ => Ok(StompVersion::Unknown(input.to_owned())),
207        }
208    }
209}
210
// Empty string constant.
// NOTE(review): not referenced by any code visible in this file; presumably used by the
// `headers!` macro expansion (e.g. as the default for string-valued headers) — confirm
// against `macros.rs`.
const EMPTY: &str = "";

// Generates the header types via the `headers!` macro from the `macros` module.
// Each entry yields a `<Name>Value` type, e.g. `Destination` -> `DestinationValue`,
// `ContentLength` -> `ContentLengthValue`, `HeartBeat` -> `HeartBeatValue` (all used by
// this module's tests).
// - `(Name, "wire-name")` entries are string-valued headers (`DestinationValue::new`
//   takes a `&str` and `value()` returns a `&str` in the tests).
// - `(Name, "wire-name", ValueType, expr)` entries carry a typed value.
//   NOTE(review): the final `expr` is presumably the header's default value — confirm
//   in `macros.rs`.
headers!(
    (Ack, "ack", AckType, (AckType::Auto)),
    (
        AcceptVersion,
        "accept-version",
        StompVersions,
        (StompVersions(Vec::new()))
    ),
    (ContentLength, "content-length", u32, 0),
    (ContentType, "content-type"),
    (Destination, "destination"),
    (
        HeartBeat,
        "heart-beat",
        HeartBeatIntervals,
        (HeartBeatIntervals::new(0, 0))
    ),
    (Host, "host"),
    (Id, "id"),
    (Login, "login"),
    (Message, "message"),
    (MessageId, "message-id"),
    (Passcode, "passcode"),
    (Receipt, "receipt"),
    (ReceiptId, "receipt-id"),
    (Server, "server"),
    (Session, "session"),
    (Subscription, "subscription"),
    (Transaction, "transaction"),
    (Version, "version", StompVersion, (StompVersion::V1_2))
);
244
#[cfg(test)]
mod test {
    use crate::common::functions::decode_str;
    use crate::error::StompParseError;
    use crate::headers::{HeartBeatIntervals, HeartBeatValue};
    use either::Either;

    use std::{fmt::Display, str::FromStr};

    use super::{ContentLengthValue, DecodableValue, DestinationValue, HeaderValue};

    fn do_something(value: &str) {
        println!("Value: {}", value);
    }

    #[test]
    fn header_value() {
        let d = DestinationValue::new("Foo");

        let value: &str = d.value();

        do_something(value);

        // Using `value` after this drop is intended to be a compile error, because it
        // should borrow from `d`. NOTE(review): confirm against the macro-generated
        // `value()` signature.
        drop(d);
    }

    #[test]
    fn header_value_display() {
        let x = ContentLengthValue::new(10);

        assert_eq!("content-length:10", x.to_string())
    }

    #[test]
    fn heartbeat_reads_supplied_then_expected() {
        let hb = HeartBeatIntervals::from_str("100,200").expect("Heartbeat parse failed");

        assert_eq!(100, hb.supplied);
        assert_eq!(200, hb.expected);
    }

    #[test]
    fn heartbeat_writes_supplied_then_expected() {
        let hb = HeartBeatIntervals::new(500, 300);

        assert_eq!("500,300", hb.to_string());
    }

    #[test]
    fn heartbeat_into_intervals() {
        let hb = HeartBeatValue::new(HeartBeatIntervals::new(123, 987));

        let intervals: HeartBeatIntervals = hb.into();

        assert_eq!(123, intervals.supplied);
        assert_eq!(987, intervals.expected);
    }

    /// Minimal header type used to exercise `decode_str` through `DecodableValue`.
    struct TestValue {
        value: &'static str,
    }

    impl TestValue {
        fn value(&self) -> &str {
            self.value
        }
    }

    impl DecodableValue for TestValue {
        fn decoded_value(&self) -> Result<Either<&str, String>, StompParseError> {
            decode_str(self.value())
        }
    }

    impl Display for TestValue {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_fmt(format_args!("test:{}", self.value))
        }
    }

    impl HeaderValue for TestValue {
        type OwnedValue = String;
        type Value = &'static str;
        const OWNED: bool = false;

        fn header_name(&self) -> &str {
            // Matches the name written by the `Display` impl above (was `todo!()`, which
            // would panic if ever called).
            "test"
        }
    }

    /// Asserts that decoding `raw` succeeds and yields the owned string `expected`.
    fn assert_decodes_to(raw: &'static str, expected: &str) {
        let instance = TestValue { value: raw };

        match instance.decoded_value() {
            Ok(Either::Right(decoded)) => assert_eq!(expected, decoded.as_str()),
            _ => panic!("Unexpected return"),
        }
    }

    /// Asserts that decoding `raw` fails.
    fn assert_rejected(raw: &'static str) {
        let instance = TestValue { value: raw };

        assert!(instance.decoded_value().is_err(), "Unexpected return");
    }

    #[test]
    fn returns_value_if_no_escape() {
        let value = "Hello";
        let instance = TestValue { value };

        match instance.decoded_value() {
            // Nothing to decode: the result must borrow the original slice, not a copy.
            Ok(Either::Left(result)) => assert_eq!(value.as_ptr(), result.as_ptr()),
            _ => panic!("Unexpected return"),
        }
    }

    #[test]
    fn transforms_escaped_slash() {
        assert_decodes_to("Hel\\\\lo", "Hel\\lo");
    }

    #[test]
    fn transforms_escaped_n() {
        assert_decodes_to("Hell\\nno", "Hell\nno");
    }

    #[test]
    fn transforms_escaped_r() {
        assert_decodes_to("Hell\\rno", "Hell\rno");
    }

    #[test]
    fn transforms_escaped_c() {
        assert_decodes_to("Hell\\cno", "Hell:no");
    }

    #[test]
    fn rejects_escaped_t() {
        assert_rejected("Hell\\tno");
    }

    #[test]
    fn rejects_slash_at_end() {
        assert_rejected("Hell\\");
    }
}