watermelon_proto/proto/decoder/
mod.rs

1use core::{mem, ops::Deref};
2
3use bytes::{Buf, Bytes, BytesMut};
4use bytestring::ByteString;
5
6use crate::{
7    MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId,
8    error::ServerError,
9    headers::{
10        HeaderMap, HeaderName, HeaderValue,
11        error::{HeaderNameValidateError, HeaderValueValidateError},
12    },
13    status_code::StatusCodeError,
14    util::{self, CrlfFinder, ParseUintError},
15};
16
17pub use self::framed::{FrameDecoderError, decode_frame};
18pub use self::stream::StreamDecoder;
19
20use super::ServerOp;
21
22mod framed;
23mod stream;
24
/// Maximum number of bytes `decode` lets accumulate in the read buffer while it is
/// still searching for the CRLF that terminates a control line (16 KiB). Once the
/// buffer reaches this size without a CRLF, `DecoderError::HeadTooLong` is returned.
const MAX_HEAD_LEN: usize = 16 * 1024;
26
/// State carried between calls to `decode`, recording which part of the current
/// server operation is expected next in the read buffer.
#[derive(Debug)]
pub(super) enum DecoderStatus {
    /// Waiting for a complete, CRLF-terminated control line.
    ControlLine {
        /// Length of the read buffer recorded the last time it was scanned without
        /// finding a CRLF. `decode` returns early while the buffer length still
        /// equals this value.
        last_bytes_read: usize,
    },
    /// An `HMSG` control line was parsed; waiting for `header_len` bytes of headers.
    Headers {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Size of the header block, as announced on the control line.
        header_len: usize,
        /// Announced total length minus `header_len`.
        payload_len: usize,
    },
    /// Control line (and headers, for `HMSG`) were parsed; waiting for
    /// `payload_len` bytes of payload followed by a 2-byte terminator.
    Payload {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Status code taken from the header block's first line; `None` for `MSG`.
        status_code: Option<StatusCode>,
        /// Parsed headers; an empty map for `MSG`.
        headers: HeaderMap,
        payload_len: usize,
    },
    /// Placeholder stored while the header block is being parsed. It is left in
    /// place if that parsing fails, after which `decode` only returns
    /// `DecoderError::Poisoned`.
    Poisoned,
}
49
50pub(super) trait BytesLike: Buf + Deref<Target = [u8]> {
51    fn len(&self) -> usize {
52        Buf::remaining(self)
53    }
54
55    fn split_to(&mut self, at: usize) -> Bytes {
56        self.copy_to_bytes(at)
57    }
58}
59
// Both buffer types use the default method bodies provided by `BytesLike`.
impl BytesLike for Bytes {}
impl BytesLike for BytesMut {}
62
63pub(super) fn decode(
64    crlf: &CrlfFinder,
65    status: &mut DecoderStatus,
66    read_buf: &mut impl BytesLike,
67) -> Result<Option<ServerOp>, DecoderError> {
68    loop {
69        match status {
70            DecoderStatus::ControlLine { last_bytes_read } => {
71                if read_buf.starts_with(b"+OK\r\n") {
72                    // Fast path for handling `+OK`
73                    debug_assert_eq!(
74                        *last_bytes_read, 0,
75                        "we shouldn't have handled any bytes before"
76                    );
77                    read_buf.advance("+OK\r\n".len());
78                    return Ok(Some(ServerOp::Success));
79                }
80
81                if *last_bytes_read == read_buf.len() {
82                    // No progress has been made
83                    return Ok(None);
84                }
85
86                let Some(control_line_len) = crlf.find(read_buf) else {
87                    return if read_buf.len() < MAX_HEAD_LEN {
88                        *last_bytes_read = read_buf.len();
89                        Ok(None)
90                    } else {
91                        Err(DecoderError::HeadTooLong {
92                            len: read_buf.len(),
93                        })
94                    };
95                };
96
97                let mut control_line = read_buf.split_to(control_line_len + "\r\n".len());
98                control_line.truncate(control_line.len() - 2);
99
100                return if control_line.starts_with(b"MSG ") {
101                    *status = decode_msg(control_line)?;
102                    continue;
103                } else if control_line.starts_with(b"HMSG ") {
104                    *status = decode_hmsg(control_line)?;
105                    continue;
106                } else if control_line.starts_with(b"PING") {
107                    Ok(Some(ServerOp::Ping))
108                } else if control_line.starts_with(b"PONG") {
109                    Ok(Some(ServerOp::Pong))
110                } else if control_line.starts_with(b"+OK") {
111                    // Slow path for handling `+OK`
112                    Ok(Some(ServerOp::Success))
113                } else if control_line.starts_with(b"-ERR ") {
114                    control_line.advance("-ERR ".len());
115                    if !control_line.starts_with(b"'") || !control_line.ends_with(b"'") {
116                        return Err(DecoderError::InvalidErrorMessage);
117                    }
118
119                    control_line.advance(1);
120                    control_line.truncate(control_line.len() - 1);
121                    let raw_message = ByteString::try_from(control_line)
122                        .map_err(|_| DecoderError::InvalidErrorMessage)?;
123                    let error = ServerError::parse(raw_message);
124                    Ok(Some(ServerOp::Error { error }))
125                } else if let Some(info) = control_line.strip_prefix(b"INFO ") {
126                    let info = serde_json::from_slice(info).map_err(DecoderError::InvalidInfo)?;
127                    Ok(Some(ServerOp::Info { info }))
128                } else {
129                    Err(DecoderError::InvalidCommand)
130                };
131            }
132            DecoderStatus::Headers { header_len, .. } => {
133                if read_buf.len() < *header_len {
134                    return Ok(None);
135                }
136
137                decode_headers(crlf, read_buf, status)?;
138            }
139            DecoderStatus::Payload { payload_len, .. } => {
140                if read_buf.len() < *payload_len + "\r\n".len() {
141                    return Ok(None);
142                }
143
144                let DecoderStatus::Payload {
145                    subscription_id,
146                    subject,
147                    reply_subject,
148                    status_code,
149                    headers,
150                    payload_len,
151                } = mem::replace(status, DecoderStatus::ControlLine { last_bytes_read: 0 })
152                else {
153                    unreachable!()
154                };
155
156                let payload = read_buf.split_to(payload_len);
157                read_buf.advance("\r\n".len());
158                let message = ServerMessage {
159                    status_code,
160                    subscription_id,
161                    base: MessageBase {
162                        subject,
163                        reply_subject,
164                        headers,
165                        payload,
166                    },
167                };
168                return Ok(Some(ServerOp::Message { message }));
169            }
170            DecoderStatus::Poisoned => return Err(DecoderError::Poisoned),
171        }
172    }
173}
174
175fn decode_msg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
176    control_line.advance("MSG ".len());
177
178    let mut chunks = util::split_spaces(control_line);
179    let (subject, subscription_id, reply_subject, payload_len) = match (
180        chunks.next(),
181        chunks.next(),
182        chunks.next(),
183        chunks.next(),
184        chunks.next(),
185    ) {
186        (Some(subject), Some(subscription_id), Some(reply_subject), Some(payload_len), None) => {
187            (subject, subscription_id, Some(reply_subject), payload_len)
188        }
189        (Some(subject), Some(subscription_id), Some(payload_len), None, None) => {
190            (subject, subscription_id, None, payload_len)
191        }
192        _ => return Err(DecoderError::InvalidMsgArgsCount),
193    };
194    let subject = Subject::from_dangerous_value(
195        subject
196            .try_into()
197            .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
198    );
199    let subscription_id =
200        SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
201    let reply_subject = reply_subject
202        .map(|reply_subject| {
203            ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
204        })
205        .transpose()?
206        .map(Subject::from_dangerous_value);
207    let payload_len =
208        util::parse_usize(&payload_len).map_err(DecoderError::InvalidPayloadLength)?;
209    Ok(DecoderStatus::Payload {
210        subscription_id,
211        subject,
212        reply_subject,
213        status_code: None,
214        headers: HeaderMap::new(),
215        payload_len,
216    })
217}
218
219fn decode_hmsg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
220    control_line.advance("HMSG ".len());
221    let mut chunks = util::split_spaces(control_line);
222
223    let (subject, subscription_id, reply_subject, header_len, total_len) = match (
224        chunks.next(),
225        chunks.next(),
226        chunks.next(),
227        chunks.next(),
228        chunks.next(),
229        chunks.next(),
230    ) {
231        (
232            Some(subject),
233            Some(subscription_id),
234            Some(reply_to),
235            Some(header_len),
236            Some(total_len),
237            None,
238        ) => (
239            subject,
240            subscription_id,
241            Some(reply_to),
242            header_len,
243            total_len,
244        ),
245        (Some(subject), Some(subscription_id), Some(header_len), Some(total_len), None, None) => {
246            (subject, subscription_id, None, header_len, total_len)
247        }
248        _ => return Err(DecoderError::InvalidHmsgArgsCount),
249    };
250
251    let subject = Subject::from_dangerous_value(
252        subject
253            .try_into()
254            .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
255    );
256    let subscription_id =
257        SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
258    let reply_subject = reply_subject
259        .map(|reply_subject| {
260            ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
261        })
262        .transpose()?
263        .map(Subject::from_dangerous_value);
264    let header_len = util::parse_usize(&header_len).map_err(DecoderError::InvalidHeaderLength)?;
265    let total_len = util::parse_usize(&total_len).map_err(DecoderError::InvalidPayloadLength)?;
266
267    let payload_len = total_len
268        .checked_sub(header_len)
269        .ok_or(DecoderError::InvalidTotalLength)?;
270
271    Ok(DecoderStatus::Headers {
272        subscription_id,
273        subject,
274        reply_subject,
275        header_len,
276        payload_len,
277    })
278}
279
280fn decode_headers(
281    crlf: &CrlfFinder,
282    read_buf: &mut impl BytesLike,
283    status: &mut DecoderStatus,
284) -> Result<(), DecoderError> {
285    let DecoderStatus::Headers {
286        subscription_id,
287        subject,
288        reply_subject,
289        header_len,
290        payload_len,
291    } = mem::replace(status, DecoderStatus::Poisoned)
292    else {
293        unreachable!()
294    };
295
296    let header = read_buf.split_to(header_len);
297    let mut lines = util::lines_iter(crlf, header);
298    let head = lines.next().ok_or(DecoderError::MissingHead)?;
299    let head = head
300        .strip_prefix(b"NATS/1.0")
301        .ok_or(DecoderError::InvalidHead)?;
302    let status_code = if let Some(status_code) = head.get(1..4) {
303        Some(StatusCode::from_ascii_bytes(status_code).map_err(DecoderError::StatusCode)?)
304    } else {
305        None
306    };
307
308    let headers = lines
309        .filter(|line| !line.is_empty())
310        .map(|mut line| {
311            let i = memchr::memchr(b':', &line).ok_or(DecoderError::InvalidHeaderLine)?;
312
313            let name = line.split_to(i);
314            line.advance(":".len());
315            if line[0].is_ascii_whitespace() {
316                // The fact that this is allowed sounds like BS to me
317                line.advance(1);
318            }
319            let value = line;
320
321            let name = HeaderName::try_from(
322                ByteString::try_from(name).map_err(|_| DecoderError::HeaderNameInvalidUtf8)?,
323            )
324            .map_err(DecoderError::HeaderName)?;
325            let value = HeaderValue::try_from(
326                ByteString::try_from(value).map_err(|_| DecoderError::HeaderValueInvalidUtf8)?,
327            )
328            .map_err(DecoderError::HeaderValue)?;
329            Ok((name, value))
330        })
331        .collect::<Result<_, _>>()?;
332
333    *status = DecoderStatus::Payload {
334        subscription_id,
335        subject,
336        reply_subject,
337        status_code,
338        headers,
339        payload_len,
340    };
341    Ok(())
342}
343
344#[derive(Debug, thiserror::Error)]
345pub enum DecoderError {
346    #[error("The head exceeded the maximum head length (len {len} maximum {MAX_HEAD_LEN}")]
347    HeadTooLong { len: usize },
348    #[error("Invalid command")]
349    InvalidCommand,
350    #[error("MSG command has an unexpected number of arguments")]
351    InvalidMsgArgsCount,
352    #[error("HMSG command has an unexpected number of arguments")]
353    InvalidHmsgArgsCount,
354    #[error("The subject isn't valid utf-8")]
355    SubjectInvalidUtf8,
356    #[error("The reply subject isn't valid utf-8")]
357    ReplySubjectInvalidUtf8,
358    #[error("Couldn't parse the Subscription ID")]
359    SubscriptionId(#[source] ParseUintError),
360    #[error("Couldn't parse the length of the header")]
361    InvalidHeaderLength(#[source] ParseUintError),
362    #[error("Couldn't parse the length of the payload")]
363    InvalidPayloadLength(#[source] ParseUintError),
364    #[error("The total length is greater than the header length")]
365    InvalidTotalLength,
366    #[error("HMSG is missing head")]
367    MissingHead,
368    #[error("HMSG has an invalid head")]
369    InvalidHead,
370    #[error("HMSG header line is missing ': '")]
371    InvalidHeaderLine,
372    #[error("Couldn't parse the status code")]
373    StatusCode(#[source] StatusCodeError),
374    #[error("The header name isn't valid utf-8")]
375    HeaderNameInvalidUtf8,
376    #[error("The header name coouldn't be parsed")]
377    HeaderName(#[source] HeaderNameValidateError),
378    #[error("The header value isn't valid utf-8")]
379    HeaderValueInvalidUtf8,
380    #[error("The header value coouldn't be parsed")]
381    HeaderValue(#[source] HeaderValueValidateError),
382    #[error("INFO command JSON payload couldn't be deserialized")]
383    InvalidInfo(#[source] serde_json::Error),
384    #[error("-ERR command message couldn't be deserialized")]
385    InvalidErrorMessage,
386    #[error("The decoder was poisoned")]
387    Poisoned,
388}