stomp_rs/protocol/
mod.rs

1pub mod frame;
2
3use std::collections::HashMap;
4use std::convert::TryFrom;
5use std::error::Error;
6use std::fmt::{Display, Formatter};
7
/// Commands that a client places on the first line of a frame.
///
/// Every variant converts to its upper-case wire name through
/// `From<ClientCommand> for &str` and is recognised again by
/// `TryFrom<&str>`.
///
/// `Eq` and `Hash` are derived next to `PartialEq` so the command can be used
/// as a key in hashed collections.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum ClientCommand {
    /// Written as `CONNECT`.
    Connect,
    /// Written as `SEND`.
    Send,
    /// Written as `SUBSCRIBE`.
    Subscribe,
    /// Written as `UNSUBSCRIBE`.
    Unsubscribe,
    /// Written as `ACK`.
    Ack,
    /// Written as `NACK`.
    Nack,
    /// Written as `BEGIN`.
    Begin,
    /// Written as `COMMIT`.
    Commit,
    /// Written as `ABORT`.
    Abort,
    /// Written as `DISCONNECT`.
    Disconnect,
}
21
/// Commands that a server places on the first line of a frame.
///
/// Every variant converts to its upper-case wire name through
/// `From<ServerCommand> for &str` and is recognised again by
/// `TryFrom<&str>`.
///
/// `Eq` and `Hash` are derived next to `PartialEq` so the command can be used
/// as a key in hashed collections.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum ServerCommand {
    /// Written as `CONNECTED`.
    Connected,
    /// Written as `MESSAGE`.
    Message,
    /// Written as `RECEIPT`.
    Receipt,
    /// Written as `ERROR`.
    Error,
}
29
30impl From<ServerCommand> for &str {
31    fn from(value: ServerCommand) -> Self {
32        match value {
33            ServerCommand::Connected => "CONNECTED",
34            ServerCommand::Message => "MESSAGE",
35            ServerCommand::Receipt => "RECEIPT",
36            ServerCommand::Error => "ERROR",
37        }
38    }
39}
40
41impl From<ClientCommand> for &str {
42    fn from(value: ClientCommand) -> Self {
43        match value {
44            ClientCommand::Connect => "CONNECT",
45            ClientCommand::Send => "SEND",
46            ClientCommand::Subscribe => "SUBSCRIBE",
47            ClientCommand::Unsubscribe => "UNSUBSCRIBE",
48            ClientCommand::Ack => "ACK",
49            ClientCommand::Nack => "NACK",
50            ClientCommand::Begin => "BEGIN",
51            ClientCommand::Commit => "COMMIT",
52            ClientCommand::Abort => "ABORT",
53            ClientCommand::Disconnect => "DISCONNECT",
54        }
55    }
56}
57
58impl TryFrom<&str> for ClientCommand {
59    type Error = &'static str;
60
61    fn try_from(value: &str) -> Result<Self, Self::Error> {
62        match value {
63            "CONNECT" => Ok(ClientCommand::Connect),
64            "SEND" => Ok(ClientCommand::Send),
65            "SUBSCRIBE" => Ok(ClientCommand::Subscribe),
66            "UNSUBSCRIBE" => Ok(ClientCommand::Unsubscribe),
67            "ACK" => Ok(ClientCommand::Ack),
68            "NACK" => Ok(ClientCommand::Nack),
69            "BEGIN" => Ok(ClientCommand::Begin),
70            "COMMIT" => Ok(ClientCommand::Commit),
71            "ABORT" => Ok(ClientCommand::Abort),
72            "DISCONNECT" => Ok(ClientCommand::Disconnect),
73            _ => Err("Unknown client command"),
74        }
75    }
76}
77
78impl TryFrom<&str> for ServerCommand {
79    type Error = &'static str;
80
81    fn try_from(value: &str) -> Result<Self, <ServerCommand as TryFrom<&'static str>>::Error> {
82        match value {
83            "CONNECTED" => Ok(ServerCommand::Connected),
84            "MESSAGE" => Ok(ServerCommand::Message),
85            "RECEIPT" => Ok(ServerCommand::Receipt),
86            "ERROR" => Ok(ServerCommand::Error),
87            _ => Err("Unknown client command"),
88        }
89    }
90}
91
92impl Command for ServerCommand {}
93
94impl Command for ClientCommand {}
95
/// A single frame: a command, its headers and a text body.
///
/// The struct itself places no bound on `T`. Operations that need to turn the
/// command into text (such as `to_bytes`) state that requirement on their own
/// `impl` block, so code that only stores or moves frames does not have to
/// repeat it.
#[derive(Debug, Clone)]
pub struct Frame<T> {
    /// Command written on the first line of the frame.
    pub command: T,
    /// Header entries keyed by header name.
    pub headers: HashMap<String, String>,
    /// Frame body as text.
    pub body: String,
}
105
106impl<T> Frame<T>
107where
108    T: Into<&'static str> + Copy,
109{
110    pub fn to_bytes(&self) -> Vec<u8> {
111        let mut buffer = vec![];
112
113        buffer.extend_from_slice(self.command.into().as_bytes());
114        buffer.push(BNF_LF);
115
116        self.headers.iter().for_each(|entry| {
117            buffer.extend_from_slice(entry.0.as_bytes());
118            buffer.extend_from_slice(":".as_bytes());
119            buffer.extend_from_slice(entry.1.as_bytes());
120            buffer.push(BNF_LF)
121        });
122
123        buffer.push(BNF_LF);
124        buffer.extend_from_slice(self.body.as_bytes());
125        buffer.push(BNF_NULL);
126
127        buffer
128    }
129}
130
/// Position of the frame parser inside the byte stream.
///
/// The state is a plain tag, so it also derives `Debug`, `Clone`, `Copy` and
/// `Eq`; that lets it be printed in assertions and copied out of the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadingState {
    /// The next line is expected to be a command name.
    Command,
    /// Header lines are being read; an empty line switches to `Body`.
    Header,
    /// Body bytes are being read up to the terminating NUL byte.
    Body,
    /// A frame has just been finished; empty lines read in this state are
    /// reported as pings and a non-empty line starts the next command.
    Completed,
}
138
/// NUL byte that terminates the body of a frame.
const BNF_NULL: u8 = b'\0';
/// Line feed that ends the command line and every header line.
pub(crate) const BNF_LF: u8 = b'\n';
/// Carriage return that may come directly before a line feed.
const BNF_CR: u8 = b'\r';
142
/// Marker for types that can act as the command of a frame.
///
/// An implementor must convert into its static wire name and must be
/// constructible, fallibly, from a borrowed string of any lifetime.
pub trait Command:
    Into<&'static str>
    + for<'a> TryFrom<&'a str>
{
}
144
145pub struct FrameParser<T: Command> {
146    buffer: Vec<u8>,
147    state: ReadingState,
148
149    current_command: Option<T>,
150    current_headers: Option<HashMap<String, String>>,
151}
152
153#[derive(Debug, Clone)]
154pub enum StompMessage<T: Command + Clone> {
155    Frame(Frame<T>),
156    Ping,
157}
158
/// Errors reported by [`FrameParser::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The command line of a frame did not name a known command. Carries the
    /// text that was found in its place.
    CommandNotFound(String),
}

impl Display for ParseError {
    /// Writes a one-line description that includes the offending command so
    /// the cause is visible in logs without inspecting the variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            // `{:?}` quotes the text and escapes control characters such as a
            // stray carriage return.
            ParseError::CommandNotFound(command) => {
                write!(f, "Parsing error: unknown command {:?}", command)
            }
        }
    }
}

impl Error for ParseError {}
171
172impl<T: Command + Clone> FrameParser<T> {
173    pub fn new() -> FrameParser<T> {
174        FrameParser {
175            buffer: vec![],
176            state: ReadingState::Command,
177            current_command: None,
178            current_headers: None,
179        }
180    }
181
182    pub fn parse(&mut self, body: &[u8]) -> Result<Vec<StompMessage<T>>, ParseError> {
183        let mut frames = vec![];
184
185        let mut body_slice = &body[..];
186
187        loop {
188            let collect_until = match self.state {
189                ReadingState::Command => BNF_LF,
190                ReadingState::Header => BNF_LF,
191                ReadingState::Body => BNF_NULL,
192                ReadingState::Completed => BNF_LF,
193            };
194
195            let position = &body_slice.iter().position(|b| *b == collect_until);
196
197            match position {
198                Some(position) => {
199                    let previous_position = position.saturating_sub(1_usize);
200
201                    let buffer_until = if collect_until == BNF_LF
202                        && body_slice
203                            .get(previous_position)
204                            .iter()
205                            .all(|b| **b == BNF_CR)
206                    {
207                        previous_position
208                    } else {
209                        *position
210                    };
211
212                    self.buffer.extend(&body_slice[..buffer_until]);
213                    body_slice = &body_slice[(u32::try_from(*position).unwrap() + 1) as usize..];
214                }
215                None => {
216                    self.buffer.extend(&body_slice[..]);
217                    break;
218                }
219            }
220
221            if let ReadingState::Completed = self.state {
222                if !self.buffer.is_empty() && self.buffer.iter().any(|b| *b != BNF_CR) {
223                    self.state = ReadingState::Command;
224                }
225            }
226
227            match self.state {
228                ReadingState::Command => {
229                    let buffer = std::mem::take(&mut self.buffer);
230                    let command_string = String::from_utf8(buffer).unwrap();
231
232                    let command = T::try_from(&command_string);
233
234                    self.current_command = match command {
235                        Ok(value) => Some(value),
236                        Err(_) => {
237                            return Err(ParseError::CommandNotFound(command_string.to_string()));
238                        }
239                    };
240
241                    self.state = ReadingState::Header;
242                    self.current_headers = Some(HashMap::new());
243                }
244                ReadingState::Header => {
245                    if self.buffer.is_empty() {
246                        self.state = ReadingState::Body;
247                    } else {
248                        let buffer = std::mem::take(&mut self.buffer);
249                        let header_line = String::from_utf8(buffer).unwrap();
250                        let mut header = header_line.split(':');
251                        self.current_headers.as_mut().unwrap().insert(
252                            header.next().unwrap().trim().to_string(),
253                            header.next().unwrap().trim().to_string(),
254                        );
255                    }
256                }
257                ReadingState::Body => {
258                    let buffer = std::mem::take(&mut self.buffer);
259                    let body = String::from_utf8(buffer).unwrap();
260
261                    self.state = ReadingState::Completed;
262
263                    let frame_command = std::mem::take(&mut self.current_command);
264                    let frame_headers = std::mem::take(&mut self.current_headers);
265
266                    frames.push(StompMessage::Frame(Frame {
267                        command: frame_command.unwrap(),
268                        headers: frame_headers.unwrap(),
269                        body,
270                    }));
271                }
272                ReadingState::Completed => {
273                    frames.push(StompMessage::Ping);
274                    self.buffer.clear();
275                }
276            }
277        }
278
279        Ok(frames)
280    }
281}
282
#[cfg(test)]
mod test {
    use crate::protocol::{ClientCommand, Frame, FrameParser, StompMessage};

    /// Runs `input` through a fresh parser four bytes at a time, so that
    /// lines and bodies are regularly split across calls.
    fn parse_in_chunks(input: &[u8]) -> Vec<StompMessage<ClientCommand>> {
        let mut parser: FrameParser<ClientCommand> = FrameParser::new();
        let mut messages = vec![];

        for chunk in input.chunks(4) {
            messages.append(&mut parser.parse(chunk).unwrap());
        }

        messages
    }

    /// Returns the frames among `messages`, in order, leaving out the pings.
    fn frames_of(messages: &[StompMessage<ClientCommand>]) -> Vec<&Frame<ClientCommand>> {
        messages
            .iter()
            .filter_map(|message| match message {
                StompMessage::Frame(frame) => Some(frame),
                StompMessage::Ping => None,
            })
            .collect()
    }

    /// Checks the two `SEND` frames shared by both fixtures. Only the body of
    /// the first frame differs between the LF and the CRLF input.
    fn assert_two_send_frames(messages: &[StompMessage<ClientCommand>], first_body: &str) {
        // Fail loudly instead of skipping the checks when no frame comes first.
        assert!(
            matches!(messages.first(), Some(StompMessage::Frame(_))),
            "the first message must be a frame"
        );

        let frames = frames_of(messages);
        assert_eq!(frames.len(), 2);
        // Four empty lines follow the first frame and one follows the second.
        assert_eq!(messages.len() - frames.len(), 5);

        let first = frames[0];
        assert_eq!(first.command, ClientCommand::Send);
        assert_eq!(first.headers.len(), 2);
        assert_eq!(first.headers.get("test").map(String::as_str), Some("value"));
        assert_eq!(
            first.headers.get("test_val").map(String::as_str),
            Some("heeerre")
        );
        assert_eq!(first.body, first_body);

        let second = frames[1];
        assert_eq!(second.command, ClientCommand::Send);
        assert_eq!(second.headers.len(), 1);
        assert_eq!(second.headers.get("test2").map(String::as_str), Some("value"));
        assert_eq!(second.body, "body : test\nsecond body");
    }

    #[test]
    fn parse_test() {
        let body = "SEND\n\
        test: value\n\
        test_val: heeerre\n\
        \n\
        body\n\
        first body\0\n\n\
        \n\
        \n\
        SEND\n\
        test2: value\n\
        \n\
        body : test\n\
        second body\0\n"
            .as_bytes();

        let messages = parse_in_chunks(body);

        assert_two_send_frames(&messages, "body\nfirst body");
    }

    #[test]
    fn parse_test_cr() {
        let body = "SEND\r\n\
        test: value\r\n\
        test_val: heeerre\r\n\
        \r\n\
        body\r\n\
        first body\0\n\n\
        \r\n\
        \r\n\
        SEND\n\
        test2: value\n\
        \n\
        body : test\n\
        second body\0\n"
            .as_bytes();

        let messages = parse_in_chunks(body);

        // Line endings inside the body are data and must be kept as they are.
        assert_two_send_frames(&messages, "body\r\nfirst body");
    }
}