// watermelon_proto/proto/decoder/mod.rs

1use core::{mem, ops::Deref};
2
3use bytes::{Buf, Bytes, BytesMut};
4use bytestring::ByteString;
5
6use crate::{
7    MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId,
8    error::ServerError,
9    headers::{
10        HeaderMap, HeaderName, HeaderValue,
11        error::{HeaderNameValidateError, HeaderValueValidateError},
12    },
13    status_code::StatusCodeError,
14    util::{self, ParseUintError},
15};
16
17pub use self::framed::{FrameDecoderError, decode_frame};
18pub use self::stream::StreamDecoder;
19
20use super::ServerOp;
21
22mod framed;
23mod stream;
24
/// Largest amount of unterminated control-line data, in bytes, that [`decode`]
/// keeps buffering before failing with [`DecoderError::HeadTooLong`].
const MAX_HEAD_LEN: usize = 16_384;
26
/// Progress of the incremental decoder, carried between calls to [`decode`].
#[derive(Debug)]
pub(super) enum DecoderStatus {
    /// Waiting for a complete `\r\n`-terminated control line.
    ControlLine {
        /// Length of `read_buf` at the last scan that found no `\r\n`;
        /// `decode` returns `Ok(None)` early while the length is unchanged.
        last_bytes_read: usize,
    },
    /// An `HMSG` control line was parsed; waiting until `header_len` bytes of
    /// headers are buffered.
    Headers {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Length of the header block announced by `HMSG`.
        header_len: usize,
        /// `HMSG` total length minus `header_len`.
        payload_len: usize,
    },
    /// The control line (and, for `HMSG`, the headers) is decoded; waiting for
    /// `payload_len` bytes of payload followed by `\r\n`.
    Payload {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Status code from the `NATS/1.0` head line; `None` for `MSG`.
        status_code: Option<StatusCode>,
        /// Decoded headers; empty for `MSG`.
        headers: HeaderMap,
        payload_len: usize,
    },
    /// Installed by `decode_headers` while it runs and left in place when
    /// header decoding fails; `decode` then returns `DecoderError::Poisoned`.
    Poisoned,
}
49
50pub(super) trait BytesLike: Buf + Deref<Target = [u8]> {
51    fn len(&self) -> usize {
52        Buf::remaining(self)
53    }
54
55    fn split_to(&mut self, at: usize) -> Bytes {
56        self.copy_to_bytes(at)
57    }
58}
59
60impl BytesLike for Bytes {}
61impl BytesLike for BytesMut {}
62
63pub(super) fn decode(
64    status: &mut DecoderStatus,
65    read_buf: &mut impl BytesLike,
66) -> Result<Option<ServerOp>, DecoderError> {
67    loop {
68        match status {
69            DecoderStatus::ControlLine { last_bytes_read } => {
70                if read_buf.starts_with(b"+OK\r\n") {
71                    // Fast path for handling `+OK`
72                    debug_assert_eq!(
73                        *last_bytes_read, 0,
74                        "we shouldn't have handled any bytes before"
75                    );
76                    read_buf.advance("+OK\r\n".len());
77                    return Ok(Some(ServerOp::Success));
78                }
79
80                if *last_bytes_read == read_buf.len() {
81                    // No progress has been made
82                    return Ok(None);
83                }
84
85                let Some(control_line_len) = memchr::memmem::find(read_buf, b"\r\n") else {
86                    return if read_buf.len() < MAX_HEAD_LEN {
87                        *last_bytes_read = read_buf.len();
88                        Ok(None)
89                    } else {
90                        Err(DecoderError::HeadTooLong {
91                            len: read_buf.len(),
92                        })
93                    };
94                };
95
96                let mut control_line = read_buf.split_to(control_line_len + "\r\n".len());
97                control_line.truncate(control_line.len() - 2);
98
99                return if control_line.starts_with(b"MSG ") {
100                    *status = decode_msg(control_line)?;
101                    continue;
102                } else if control_line.starts_with(b"HMSG ") {
103                    *status = decode_hmsg(control_line)?;
104                    continue;
105                } else if control_line.starts_with(b"PING") {
106                    Ok(Some(ServerOp::Ping))
107                } else if control_line.starts_with(b"PONG") {
108                    Ok(Some(ServerOp::Pong))
109                } else if control_line.starts_with(b"+OK") {
110                    // Slow path for handling `+OK`
111                    Ok(Some(ServerOp::Success))
112                } else if control_line.starts_with(b"-ERR ") {
113                    control_line.advance("-ERR ".len());
114                    if !control_line.starts_with(b"'") || !control_line.ends_with(b"'") {
115                        return Err(DecoderError::InvalidErrorMessage);
116                    }
117
118                    control_line.advance(1);
119                    control_line.truncate(control_line.len() - 1);
120                    let raw_message = ByteString::try_from(control_line)
121                        .map_err(|_| DecoderError::InvalidErrorMessage)?;
122                    let error = ServerError::parse(raw_message);
123                    Ok(Some(ServerOp::Error { error }))
124                } else if let Some(info) = control_line.strip_prefix(b"INFO ") {
125                    let info = serde_json::from_slice(info).map_err(DecoderError::InvalidInfo)?;
126                    Ok(Some(ServerOp::Info { info }))
127                } else {
128                    Err(DecoderError::InvalidCommand)
129                };
130            }
131            DecoderStatus::Headers { header_len, .. } => {
132                if read_buf.len() < *header_len {
133                    return Ok(None);
134                }
135
136                decode_headers(read_buf, status)?;
137            }
138            DecoderStatus::Payload { payload_len, .. } => {
139                if read_buf.len() < *payload_len + "\r\n".len() {
140                    return Ok(None);
141                }
142
143                let DecoderStatus::Payload {
144                    subscription_id,
145                    subject,
146                    reply_subject,
147                    status_code,
148                    headers,
149                    payload_len,
150                } = mem::replace(status, DecoderStatus::ControlLine { last_bytes_read: 0 })
151                else {
152                    unreachable!()
153                };
154
155                let payload = read_buf.split_to(payload_len);
156                read_buf.advance("\r\n".len());
157                let message = ServerMessage {
158                    status_code,
159                    subscription_id,
160                    base: MessageBase {
161                        subject,
162                        reply_subject,
163                        headers,
164                        payload,
165                    },
166                };
167                return Ok(Some(ServerOp::Message { message }));
168            }
169            DecoderStatus::Poisoned => return Err(DecoderError::Poisoned),
170        }
171    }
172}
173
174fn decode_msg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
175    control_line.advance("MSG ".len());
176
177    let mut chunks = util::split_spaces(control_line);
178    let (subject, subscription_id, reply_subject, payload_len) = match (
179        chunks.next(),
180        chunks.next(),
181        chunks.next(),
182        chunks.next(),
183        chunks.next(),
184    ) {
185        (Some(subject), Some(subscription_id), Some(reply_subject), Some(payload_len), None) => {
186            (subject, subscription_id, Some(reply_subject), payload_len)
187        }
188        (Some(subject), Some(subscription_id), Some(payload_len), None, None) => {
189            (subject, subscription_id, None, payload_len)
190        }
191        _ => return Err(DecoderError::InvalidMsgArgsCount),
192    };
193    let subject = Subject::from_dangerous_value(
194        subject
195            .try_into()
196            .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
197    );
198    let subscription_id =
199        SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
200    let reply_subject = reply_subject
201        .map(|reply_subject| {
202            ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
203        })
204        .transpose()?
205        .map(Subject::from_dangerous_value);
206    let payload_len =
207        util::parse_usize(&payload_len).map_err(DecoderError::InvalidPayloadLength)?;
208    Ok(DecoderStatus::Payload {
209        subscription_id,
210        subject,
211        reply_subject,
212        status_code: None,
213        headers: HeaderMap::new(),
214        payload_len,
215    })
216}
217
218fn decode_hmsg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
219    control_line.advance("HMSG ".len());
220    let mut chunks = util::split_spaces(control_line);
221
222    let (subject, subscription_id, reply_subject, header_len, total_len) = match (
223        chunks.next(),
224        chunks.next(),
225        chunks.next(),
226        chunks.next(),
227        chunks.next(),
228        chunks.next(),
229    ) {
230        (
231            Some(subject),
232            Some(subscription_id),
233            Some(reply_to),
234            Some(header_len),
235            Some(total_len),
236            None,
237        ) => (
238            subject,
239            subscription_id,
240            Some(reply_to),
241            header_len,
242            total_len,
243        ),
244        (Some(subject), Some(subscription_id), Some(header_len), Some(total_len), None, None) => {
245            (subject, subscription_id, None, header_len, total_len)
246        }
247        _ => return Err(DecoderError::InvalidHmsgArgsCount),
248    };
249
250    let subject = Subject::from_dangerous_value(
251        subject
252            .try_into()
253            .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
254    );
255    let subscription_id =
256        SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
257    let reply_subject = reply_subject
258        .map(|reply_subject| {
259            ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
260        })
261        .transpose()?
262        .map(Subject::from_dangerous_value);
263    let header_len = util::parse_usize(&header_len).map_err(DecoderError::InvalidHeaderLength)?;
264    let total_len = util::parse_usize(&total_len).map_err(DecoderError::InvalidPayloadLength)?;
265
266    let payload_len = total_len
267        .checked_sub(header_len)
268        .ok_or(DecoderError::InvalidTotalLength)?;
269
270    Ok(DecoderStatus::Headers {
271        subscription_id,
272        subject,
273        reply_subject,
274        header_len,
275        payload_len,
276    })
277}
278
279fn decode_headers(
280    read_buf: &mut impl BytesLike,
281    status: &mut DecoderStatus,
282) -> Result<(), DecoderError> {
283    let DecoderStatus::Headers {
284        subscription_id,
285        subject,
286        reply_subject,
287        header_len,
288        payload_len,
289    } = mem::replace(status, DecoderStatus::Poisoned)
290    else {
291        unreachable!()
292    };
293
294    let header = read_buf.split_to(header_len);
295    let mut lines = util::lines_iter(header);
296    let head = lines.next().ok_or(DecoderError::MissingHead)?;
297    let head = head
298        .strip_prefix(b"NATS/1.0")
299        .ok_or(DecoderError::InvalidHead)?;
300    let status_code = if head.len() >= 4 {
301        Some(StatusCode::from_ascii_bytes(&head[1..4]).map_err(DecoderError::StatusCode)?)
302    } else {
303        None
304    };
305
306    let headers = lines
307        .filter(|line| !line.is_empty())
308        .map(|mut line| {
309            let i = memchr::memchr(b':', &line).ok_or(DecoderError::InvalidHeaderLine)?;
310
311            let name = line.split_to(i);
312            line.advance(":".len());
313            if line[0].is_ascii_whitespace() {
314                // The fact that this is allowed sounds like BS to me
315                line.advance(1);
316            }
317            let value = line;
318
319            let name = HeaderName::try_from(
320                ByteString::try_from(name).map_err(|_| DecoderError::HeaderNameInvalidUtf8)?,
321            )
322            .map_err(DecoderError::HeaderName)?;
323            let value = HeaderValue::try_from(
324                ByteString::try_from(value).map_err(|_| DecoderError::HeaderValueInvalidUtf8)?,
325            )
326            .map_err(DecoderError::HeaderValue)?;
327            Ok((name, value))
328        })
329        .collect::<Result<_, _>>()?;
330
331    *status = DecoderStatus::Payload {
332        subscription_id,
333        subject,
334        reply_subject,
335        status_code,
336        headers,
337        payload_len,
338    };
339    Ok(())
340}
341
342#[derive(Debug, thiserror::Error)]
343pub enum DecoderError {
344    #[error("The head exceeded the maximum head length (len {len} maximum {MAX_HEAD_LEN}")]
345    HeadTooLong { len: usize },
346    #[error("Invalid command")]
347    InvalidCommand,
348    #[error("MSG command has an unexpected number of arguments")]
349    InvalidMsgArgsCount,
350    #[error("HMSG command has an unexpected number of arguments")]
351    InvalidHmsgArgsCount,
352    #[error("The subject isn't valid utf-8")]
353    SubjectInvalidUtf8,
354    #[error("The reply subject isn't valid utf-8")]
355    ReplySubjectInvalidUtf8,
356    #[error("Couldn't parse the Subscription ID")]
357    SubscriptionId(#[source] ParseUintError),
358    #[error("Couldn't parse the length of the header")]
359    InvalidHeaderLength(#[source] ParseUintError),
360    #[error("Couldn't parse the length of the payload")]
361    InvalidPayloadLength(#[source] ParseUintError),
362    #[error("The total length is greater than the header length")]
363    InvalidTotalLength,
364    #[error("HMSG is missing head")]
365    MissingHead,
366    #[error("HMSG has an invalid head")]
367    InvalidHead,
368    #[error("HMSG header line is missing ': '")]
369    InvalidHeaderLine,
370    #[error("Couldn't parse the status code")]
371    StatusCode(#[source] StatusCodeError),
372    #[error("The header name isn't valid utf-8")]
373    HeaderNameInvalidUtf8,
374    #[error("The header name coouldn't be parsed")]
375    HeaderName(#[source] HeaderNameValidateError),
376    #[error("The header value isn't valid utf-8")]
377    HeaderValueInvalidUtf8,
378    #[error("The header value coouldn't be parsed")]
379    HeaderValue(#[source] HeaderValueValidateError),
380    #[error("INFO command JSON payload couldn't be deserialized")]
381    InvalidInfo(#[source] serde_json::Error),
382    #[error("-ERR command message couldn't be deserialized")]
383    InvalidErrorMessage,
384    #[error("The decoder was poisoned")]
385    Poisoned,
386}