1use core::{mem, ops::Deref};
2
3use bytes::{Buf, Bytes, BytesMut};
4use bytestring::ByteString;
5
6use crate::{
7 MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId,
8 error::ServerError,
9 headers::{
10 HeaderMap, HeaderName, HeaderValue,
11 error::{HeaderNameValidateError, HeaderValueValidateError},
12 },
13 status_code::StatusCodeError,
14 util::{self, ParseUintError},
15};
16
17pub use self::framed::{FrameDecoderError, decode_frame};
18pub use self::stream::StreamDecoder;
19
20use super::ServerOp;
21
22mod framed;
23mod stream;
24
/// Upper bound (16 KiB) on how many bytes the decoder will buffer while looking
/// for the CRLF that terminates a control line. A buffer reaching this size
/// without a CRLF is rejected with [`DecoderError::HeadTooLong`].
const MAX_HEAD_LEN: usize = 16_384;
26
/// Resumable state of [`decode`] between calls.
///
/// While assembling a message the decoder moves from `ControlLine` to
/// `Payload` (via `Headers` for `HMSG`) and back to `ControlLine` once the
/// payload has been consumed.
#[derive(Debug)]
pub(super) enum DecoderStatus {
    /// Waiting for a complete CRLF-terminated control line.
    ControlLine {
        /// Buffer length seen by the last scan that found no CRLF; if the
        /// buffer still has exactly this length, the scan is skipped.
        last_bytes_read: usize,
    },
    /// An `HMSG` control line was parsed; waiting for `header_len` header bytes.
    Headers {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Length of the header section announced by `HMSG`.
        header_len: usize,
        /// Announced total length minus `header_len`.
        payload_len: usize,
    },
    /// Waiting for `payload_len` payload bytes followed by CRLF.
    Payload {
        subscription_id: SubscriptionId,
        subject: Subject,
        reply_subject: Option<Subject>,
        /// Status code from the `HMSG` head line, if present; always `None` for `MSG`.
        status_code: Option<StatusCode>,
        /// Parsed headers; empty for `MSG`.
        headers: HeaderMap,
        payload_len: usize,
    },
    /// Placeholder installed while headers are decoded; it remains if header
    /// decoding fails, and [`decode`] then returns [`DecoderError::Poisoned`].
    Poisoned,
}
49
50pub(super) trait BytesLike: Buf + Deref<Target = [u8]> {
51 fn len(&self) -> usize {
52 Buf::remaining(self)
53 }
54
55 fn split_to(&mut self, at: usize) -> Bytes {
56 self.copy_to_bytes(at)
57 }
58}
59
60impl BytesLike for Bytes {}
61impl BytesLike for BytesMut {}
62
63pub(super) fn decode(
64 status: &mut DecoderStatus,
65 read_buf: &mut impl BytesLike,
66) -> Result<Option<ServerOp>, DecoderError> {
67 loop {
68 match status {
69 DecoderStatus::ControlLine { last_bytes_read } => {
70 if read_buf.starts_with(b"+OK\r\n") {
71 debug_assert_eq!(
73 *last_bytes_read, 0,
74 "we shouldn't have handled any bytes before"
75 );
76 read_buf.advance("+OK\r\n".len());
77 return Ok(Some(ServerOp::Success));
78 }
79
80 if *last_bytes_read == read_buf.len() {
81 return Ok(None);
83 }
84
85 let Some(control_line_len) = memchr::memmem::find(read_buf, b"\r\n") else {
86 return if read_buf.len() < MAX_HEAD_LEN {
87 *last_bytes_read = read_buf.len();
88 Ok(None)
89 } else {
90 Err(DecoderError::HeadTooLong {
91 len: read_buf.len(),
92 })
93 };
94 };
95
96 let mut control_line = read_buf.split_to(control_line_len + "\r\n".len());
97 control_line.truncate(control_line.len() - 2);
98
99 return if control_line.starts_with(b"MSG ") {
100 *status = decode_msg(control_line)?;
101 continue;
102 } else if control_line.starts_with(b"HMSG ") {
103 *status = decode_hmsg(control_line)?;
104 continue;
105 } else if control_line.starts_with(b"PING") {
106 Ok(Some(ServerOp::Ping))
107 } else if control_line.starts_with(b"PONG") {
108 Ok(Some(ServerOp::Pong))
109 } else if control_line.starts_with(b"+OK") {
110 Ok(Some(ServerOp::Success))
112 } else if control_line.starts_with(b"-ERR ") {
113 control_line.advance("-ERR ".len());
114 if !control_line.starts_with(b"'") || !control_line.ends_with(b"'") {
115 return Err(DecoderError::InvalidErrorMessage);
116 }
117
118 control_line.advance(1);
119 control_line.truncate(control_line.len() - 1);
120 let raw_message = ByteString::try_from(control_line)
121 .map_err(|_| DecoderError::InvalidErrorMessage)?;
122 let error = ServerError::parse(raw_message);
123 Ok(Some(ServerOp::Error { error }))
124 } else if let Some(info) = control_line.strip_prefix(b"INFO ") {
125 let info = serde_json::from_slice(info).map_err(DecoderError::InvalidInfo)?;
126 Ok(Some(ServerOp::Info { info }))
127 } else {
128 Err(DecoderError::InvalidCommand)
129 };
130 }
131 DecoderStatus::Headers { header_len, .. } => {
132 if read_buf.len() < *header_len {
133 return Ok(None);
134 }
135
136 decode_headers(read_buf, status)?;
137 }
138 DecoderStatus::Payload { payload_len, .. } => {
139 if read_buf.len() < *payload_len + "\r\n".len() {
140 return Ok(None);
141 }
142
143 let DecoderStatus::Payload {
144 subscription_id,
145 subject,
146 reply_subject,
147 status_code,
148 headers,
149 payload_len,
150 } = mem::replace(status, DecoderStatus::ControlLine { last_bytes_read: 0 })
151 else {
152 unreachable!()
153 };
154
155 let payload = read_buf.split_to(payload_len);
156 read_buf.advance("\r\n".len());
157 let message = ServerMessage {
158 status_code,
159 subscription_id,
160 base: MessageBase {
161 subject,
162 reply_subject,
163 headers,
164 payload,
165 },
166 };
167 return Ok(Some(ServerOp::Message { message }));
168 }
169 DecoderStatus::Poisoned => return Err(DecoderError::Poisoned),
170 }
171 }
172}
173
174fn decode_msg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
175 control_line.advance("MSG ".len());
176
177 let mut chunks = util::split_spaces(control_line);
178 let (subject, subscription_id, reply_subject, payload_len) = match (
179 chunks.next(),
180 chunks.next(),
181 chunks.next(),
182 chunks.next(),
183 chunks.next(),
184 ) {
185 (Some(subject), Some(subscription_id), Some(reply_subject), Some(payload_len), None) => {
186 (subject, subscription_id, Some(reply_subject), payload_len)
187 }
188 (Some(subject), Some(subscription_id), Some(payload_len), None, None) => {
189 (subject, subscription_id, None, payload_len)
190 }
191 _ => return Err(DecoderError::InvalidMsgArgsCount),
192 };
193 let subject = Subject::from_dangerous_value(
194 subject
195 .try_into()
196 .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
197 );
198 let subscription_id =
199 SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
200 let reply_subject = reply_subject
201 .map(|reply_subject| {
202 ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
203 })
204 .transpose()?
205 .map(Subject::from_dangerous_value);
206 let payload_len =
207 util::parse_usize(&payload_len).map_err(DecoderError::InvalidPayloadLength)?;
208 Ok(DecoderStatus::Payload {
209 subscription_id,
210 subject,
211 reply_subject,
212 status_code: None,
213 headers: HeaderMap::new(),
214 payload_len,
215 })
216}
217
218fn decode_hmsg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
219 control_line.advance("HMSG ".len());
220 let mut chunks = util::split_spaces(control_line);
221
222 let (subject, subscription_id, reply_subject, header_len, total_len) = match (
223 chunks.next(),
224 chunks.next(),
225 chunks.next(),
226 chunks.next(),
227 chunks.next(),
228 chunks.next(),
229 ) {
230 (
231 Some(subject),
232 Some(subscription_id),
233 Some(reply_to),
234 Some(header_len),
235 Some(total_len),
236 None,
237 ) => (
238 subject,
239 subscription_id,
240 Some(reply_to),
241 header_len,
242 total_len,
243 ),
244 (Some(subject), Some(subscription_id), Some(header_len), Some(total_len), None, None) => {
245 (subject, subscription_id, None, header_len, total_len)
246 }
247 _ => return Err(DecoderError::InvalidHmsgArgsCount),
248 };
249
250 let subject = Subject::from_dangerous_value(
251 subject
252 .try_into()
253 .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
254 );
255 let subscription_id =
256 SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
257 let reply_subject = reply_subject
258 .map(|reply_subject| {
259 ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
260 })
261 .transpose()?
262 .map(Subject::from_dangerous_value);
263 let header_len = util::parse_usize(&header_len).map_err(DecoderError::InvalidHeaderLength)?;
264 let total_len = util::parse_usize(&total_len).map_err(DecoderError::InvalidPayloadLength)?;
265
266 let payload_len = total_len
267 .checked_sub(header_len)
268 .ok_or(DecoderError::InvalidTotalLength)?;
269
270 Ok(DecoderStatus::Headers {
271 subscription_id,
272 subject,
273 reply_subject,
274 header_len,
275 payload_len,
276 })
277}
278
279fn decode_headers(
280 read_buf: &mut impl BytesLike,
281 status: &mut DecoderStatus,
282) -> Result<(), DecoderError> {
283 let DecoderStatus::Headers {
284 subscription_id,
285 subject,
286 reply_subject,
287 header_len,
288 payload_len,
289 } = mem::replace(status, DecoderStatus::Poisoned)
290 else {
291 unreachable!()
292 };
293
294 let header = read_buf.split_to(header_len);
295 let mut lines = util::lines_iter(header);
296 let head = lines.next().ok_or(DecoderError::MissingHead)?;
297 let head = head
298 .strip_prefix(b"NATS/1.0")
299 .ok_or(DecoderError::InvalidHead)?;
300 let status_code = if head.len() >= 4 {
301 Some(StatusCode::from_ascii_bytes(&head[1..4]).map_err(DecoderError::StatusCode)?)
302 } else {
303 None
304 };
305
306 let headers = lines
307 .filter(|line| !line.is_empty())
308 .map(|mut line| {
309 let i = memchr::memchr(b':', &line).ok_or(DecoderError::InvalidHeaderLine)?;
310
311 let name = line.split_to(i);
312 line.advance(":".len());
313 if line[0].is_ascii_whitespace() {
314 line.advance(1);
316 }
317 let value = line;
318
319 let name = HeaderName::try_from(
320 ByteString::try_from(name).map_err(|_| DecoderError::HeaderNameInvalidUtf8)?,
321 )
322 .map_err(DecoderError::HeaderName)?;
323 let value = HeaderValue::try_from(
324 ByteString::try_from(value).map_err(|_| DecoderError::HeaderValueInvalidUtf8)?,
325 )
326 .map_err(DecoderError::HeaderValue)?;
327 Ok((name, value))
328 })
329 .collect::<Result<_, _>>()?;
330
331 *status = DecoderStatus::Payload {
332 subscription_id,
333 subject,
334 reply_subject,
335 status_code,
336 headers,
337 payload_len,
338 };
339 Ok(())
340}
341
342#[derive(Debug, thiserror::Error)]
343pub enum DecoderError {
344 #[error("The head exceeded the maximum head length (len {len} maximum {MAX_HEAD_LEN}")]
345 HeadTooLong { len: usize },
346 #[error("Invalid command")]
347 InvalidCommand,
348 #[error("MSG command has an unexpected number of arguments")]
349 InvalidMsgArgsCount,
350 #[error("HMSG command has an unexpected number of arguments")]
351 InvalidHmsgArgsCount,
352 #[error("The subject isn't valid utf-8")]
353 SubjectInvalidUtf8,
354 #[error("The reply subject isn't valid utf-8")]
355 ReplySubjectInvalidUtf8,
356 #[error("Couldn't parse the Subscription ID")]
357 SubscriptionId(#[source] ParseUintError),
358 #[error("Couldn't parse the length of the header")]
359 InvalidHeaderLength(#[source] ParseUintError),
360 #[error("Couldn't parse the length of the payload")]
361 InvalidPayloadLength(#[source] ParseUintError),
362 #[error("The total length is greater than the header length")]
363 InvalidTotalLength,
364 #[error("HMSG is missing head")]
365 MissingHead,
366 #[error("HMSG has an invalid head")]
367 InvalidHead,
368 #[error("HMSG header line is missing ': '")]
369 InvalidHeaderLine,
370 #[error("Couldn't parse the status code")]
371 StatusCode(#[source] StatusCodeError),
372 #[error("The header name isn't valid utf-8")]
373 HeaderNameInvalidUtf8,
374 #[error("The header name coouldn't be parsed")]
375 HeaderName(#[source] HeaderNameValidateError),
376 #[error("The header value isn't valid utf-8")]
377 HeaderValueInvalidUtf8,
378 #[error("The header value coouldn't be parsed")]
379 HeaderValue(#[source] HeaderValueValidateError),
380 #[error("INFO command JSON payload couldn't be deserialized")]
381 InvalidInfo(#[source] serde_json::Error),
382 #[error("-ERR command message couldn't be deserialized")]
383 InvalidErrorMessage,
384 #[error("The decoder was poisoned")]
385 Poisoned,
386}