1use core::{mem, ops::Deref};
2
3use bytes::{Buf, Bytes, BytesMut};
4use bytestring::ByteString;
5
6use crate::{
7 MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId,
8 error::ServerError,
9 headers::{
10 HeaderMap, HeaderName, HeaderValue,
11 error::{HeaderNameValidateError, HeaderValueValidateError},
12 },
13 status_code::StatusCodeError,
14 util::{self, CrlfFinder, ParseUintError},
15};
16
17pub use self::framed::{FrameDecoderError, decode_frame};
18pub use self::stream::StreamDecoder;
19
20use super::ServerOp;
21
22mod framed;
23mod stream;
24
/// Largest number of buffered bytes (16 KiB) tolerated while no CRLF-terminated control
/// line has been found; past this, `decode` fails with `DecoderError::HeadTooLong`.
const MAX_HEAD_LEN: usize = 16_384;
26
27#[derive(Debug)]
28pub(super) enum DecoderStatus {
29 ControlLine {
30 last_bytes_read: usize,
31 },
32 Headers {
33 subscription_id: SubscriptionId,
34 subject: Subject,
35 reply_subject: Option<Subject>,
36 header_len: usize,
37 payload_len: usize,
38 },
39 Payload {
40 subscription_id: SubscriptionId,
41 subject: Subject,
42 reply_subject: Option<Subject>,
43 status_code: Option<StatusCode>,
44 headers: HeaderMap,
45 payload_len: usize,
46 },
47 Poisoned,
48}
49
50pub(super) trait BytesLike: Buf + Deref<Target = [u8]> {
51 fn len(&self) -> usize {
52 Buf::remaining(self)
53 }
54
55 fn split_to(&mut self, at: usize) -> Bytes {
56 self.copy_to_bytes(at)
57 }
58}
59
60impl BytesLike for Bytes {}
61impl BytesLike for BytesMut {}
62
63pub(super) fn decode(
64 crlf: &CrlfFinder,
65 status: &mut DecoderStatus,
66 read_buf: &mut impl BytesLike,
67) -> Result<Option<ServerOp>, DecoderError> {
68 loop {
69 match status {
70 DecoderStatus::ControlLine { last_bytes_read } => {
71 if read_buf.starts_with(b"+OK\r\n") {
72 debug_assert_eq!(
74 *last_bytes_read, 0,
75 "we shouldn't have handled any bytes before"
76 );
77 read_buf.advance("+OK\r\n".len());
78 return Ok(Some(ServerOp::Success));
79 }
80
81 if *last_bytes_read == read_buf.len() {
82 return Ok(None);
84 }
85
86 let Some(control_line_len) = crlf.find(read_buf) else {
87 return if read_buf.len() < MAX_HEAD_LEN {
88 *last_bytes_read = read_buf.len();
89 Ok(None)
90 } else {
91 Err(DecoderError::HeadTooLong {
92 len: read_buf.len(),
93 })
94 };
95 };
96
97 let mut control_line = read_buf.split_to(control_line_len + "\r\n".len());
98 control_line.truncate(control_line.len() - 2);
99
100 return if control_line.starts_with(b"MSG ") {
101 *status = decode_msg(control_line)?;
102 continue;
103 } else if control_line.starts_with(b"HMSG ") {
104 *status = decode_hmsg(control_line)?;
105 continue;
106 } else if control_line.starts_with(b"PING") {
107 Ok(Some(ServerOp::Ping))
108 } else if control_line.starts_with(b"PONG") {
109 Ok(Some(ServerOp::Pong))
110 } else if control_line.starts_with(b"+OK") {
111 Ok(Some(ServerOp::Success))
113 } else if control_line.starts_with(b"-ERR ") {
114 control_line.advance("-ERR ".len());
115 if !control_line.starts_with(b"'") || !control_line.ends_with(b"'") {
116 return Err(DecoderError::InvalidErrorMessage);
117 }
118
119 control_line.advance(1);
120 control_line.truncate(control_line.len() - 1);
121 let raw_message = ByteString::try_from(control_line)
122 .map_err(|_| DecoderError::InvalidErrorMessage)?;
123 let error = ServerError::parse(raw_message);
124 Ok(Some(ServerOp::Error { error }))
125 } else if let Some(info) = control_line.strip_prefix(b"INFO ") {
126 let info = serde_json::from_slice(info).map_err(DecoderError::InvalidInfo)?;
127 Ok(Some(ServerOp::Info { info }))
128 } else {
129 Err(DecoderError::InvalidCommand)
130 };
131 }
132 DecoderStatus::Headers { header_len, .. } => {
133 if read_buf.len() < *header_len {
134 return Ok(None);
135 }
136
137 decode_headers(crlf, read_buf, status)?;
138 }
139 DecoderStatus::Payload { payload_len, .. } => {
140 if read_buf.len() < *payload_len + "\r\n".len() {
141 return Ok(None);
142 }
143
144 let DecoderStatus::Payload {
145 subscription_id,
146 subject,
147 reply_subject,
148 status_code,
149 headers,
150 payload_len,
151 } = mem::replace(status, DecoderStatus::ControlLine { last_bytes_read: 0 })
152 else {
153 unreachable!()
154 };
155
156 let payload = read_buf.split_to(payload_len);
157 read_buf.advance("\r\n".len());
158 let message = ServerMessage {
159 status_code,
160 subscription_id,
161 base: MessageBase {
162 subject,
163 reply_subject,
164 headers,
165 payload,
166 },
167 };
168 return Ok(Some(ServerOp::Message { message }));
169 }
170 DecoderStatus::Poisoned => return Err(DecoderError::Poisoned),
171 }
172 }
173}
174
175fn decode_msg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
176 control_line.advance("MSG ".len());
177
178 let mut chunks = util::split_spaces(control_line);
179 let (subject, subscription_id, reply_subject, payload_len) = match (
180 chunks.next(),
181 chunks.next(),
182 chunks.next(),
183 chunks.next(),
184 chunks.next(),
185 ) {
186 (Some(subject), Some(subscription_id), Some(reply_subject), Some(payload_len), None) => {
187 (subject, subscription_id, Some(reply_subject), payload_len)
188 }
189 (Some(subject), Some(subscription_id), Some(payload_len), None, None) => {
190 (subject, subscription_id, None, payload_len)
191 }
192 _ => return Err(DecoderError::InvalidMsgArgsCount),
193 };
194 let subject = Subject::from_dangerous_value(
195 subject
196 .try_into()
197 .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
198 );
199 let subscription_id =
200 SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
201 let reply_subject = reply_subject
202 .map(|reply_subject| {
203 ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
204 })
205 .transpose()?
206 .map(Subject::from_dangerous_value);
207 let payload_len =
208 util::parse_usize(&payload_len).map_err(DecoderError::InvalidPayloadLength)?;
209 Ok(DecoderStatus::Payload {
210 subscription_id,
211 subject,
212 reply_subject,
213 status_code: None,
214 headers: HeaderMap::new(),
215 payload_len,
216 })
217}
218
219fn decode_hmsg(mut control_line: Bytes) -> Result<DecoderStatus, DecoderError> {
220 control_line.advance("HMSG ".len());
221 let mut chunks = util::split_spaces(control_line);
222
223 let (subject, subscription_id, reply_subject, header_len, total_len) = match (
224 chunks.next(),
225 chunks.next(),
226 chunks.next(),
227 chunks.next(),
228 chunks.next(),
229 chunks.next(),
230 ) {
231 (
232 Some(subject),
233 Some(subscription_id),
234 Some(reply_to),
235 Some(header_len),
236 Some(total_len),
237 None,
238 ) => (
239 subject,
240 subscription_id,
241 Some(reply_to),
242 header_len,
243 total_len,
244 ),
245 (Some(subject), Some(subscription_id), Some(header_len), Some(total_len), None, None) => {
246 (subject, subscription_id, None, header_len, total_len)
247 }
248 _ => return Err(DecoderError::InvalidHmsgArgsCount),
249 };
250
251 let subject = Subject::from_dangerous_value(
252 subject
253 .try_into()
254 .map_err(|_| DecoderError::SubjectInvalidUtf8)?,
255 );
256 let subscription_id =
257 SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?;
258 let reply_subject = reply_subject
259 .map(|reply_subject| {
260 ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8)
261 })
262 .transpose()?
263 .map(Subject::from_dangerous_value);
264 let header_len = util::parse_usize(&header_len).map_err(DecoderError::InvalidHeaderLength)?;
265 let total_len = util::parse_usize(&total_len).map_err(DecoderError::InvalidPayloadLength)?;
266
267 let payload_len = total_len
268 .checked_sub(header_len)
269 .ok_or(DecoderError::InvalidTotalLength)?;
270
271 Ok(DecoderStatus::Headers {
272 subscription_id,
273 subject,
274 reply_subject,
275 header_len,
276 payload_len,
277 })
278}
279
280fn decode_headers(
281 crlf: &CrlfFinder,
282 read_buf: &mut impl BytesLike,
283 status: &mut DecoderStatus,
284) -> Result<(), DecoderError> {
285 let DecoderStatus::Headers {
286 subscription_id,
287 subject,
288 reply_subject,
289 header_len,
290 payload_len,
291 } = mem::replace(status, DecoderStatus::Poisoned)
292 else {
293 unreachable!()
294 };
295
296 let header = read_buf.split_to(header_len);
297 let mut lines = util::lines_iter(crlf, header);
298 let head = lines.next().ok_or(DecoderError::MissingHead)?;
299 let head = head
300 .strip_prefix(b"NATS/1.0")
301 .ok_or(DecoderError::InvalidHead)?;
302 let status_code = if let Some(status_code) = head.get(1..4) {
303 Some(StatusCode::from_ascii_bytes(status_code).map_err(DecoderError::StatusCode)?)
304 } else {
305 None
306 };
307
308 let headers = lines
309 .filter(|line| !line.is_empty())
310 .map(|mut line| {
311 let i = memchr::memchr(b':', &line).ok_or(DecoderError::InvalidHeaderLine)?;
312
313 let name = line.split_to(i);
314 line.advance(":".len());
315 if line[0].is_ascii_whitespace() {
316 line.advance(1);
318 }
319 let value = line;
320
321 let name = HeaderName::try_from(
322 ByteString::try_from(name).map_err(|_| DecoderError::HeaderNameInvalidUtf8)?,
323 )
324 .map_err(DecoderError::HeaderName)?;
325 let value = HeaderValue::try_from(
326 ByteString::try_from(value).map_err(|_| DecoderError::HeaderValueInvalidUtf8)?,
327 )
328 .map_err(DecoderError::HeaderValue)?;
329 Ok((name, value))
330 })
331 .collect::<Result<_, _>>()?;
332
333 *status = DecoderStatus::Payload {
334 subscription_id,
335 subject,
336 reply_subject,
337 status_code,
338 headers,
339 payload_len,
340 };
341 Ok(())
342}
343
344#[derive(Debug, thiserror::Error)]
345pub enum DecoderError {
346 #[error("The head exceeded the maximum head length (len {len} maximum {MAX_HEAD_LEN}")]
347 HeadTooLong { len: usize },
348 #[error("Invalid command")]
349 InvalidCommand,
350 #[error("MSG command has an unexpected number of arguments")]
351 InvalidMsgArgsCount,
352 #[error("HMSG command has an unexpected number of arguments")]
353 InvalidHmsgArgsCount,
354 #[error("The subject isn't valid utf-8")]
355 SubjectInvalidUtf8,
356 #[error("The reply subject isn't valid utf-8")]
357 ReplySubjectInvalidUtf8,
358 #[error("Couldn't parse the Subscription ID")]
359 SubscriptionId(#[source] ParseUintError),
360 #[error("Couldn't parse the length of the header")]
361 InvalidHeaderLength(#[source] ParseUintError),
362 #[error("Couldn't parse the length of the payload")]
363 InvalidPayloadLength(#[source] ParseUintError),
364 #[error("The total length is greater than the header length")]
365 InvalidTotalLength,
366 #[error("HMSG is missing head")]
367 MissingHead,
368 #[error("HMSG has an invalid head")]
369 InvalidHead,
370 #[error("HMSG header line is missing ': '")]
371 InvalidHeaderLine,
372 #[error("Couldn't parse the status code")]
373 StatusCode(#[source] StatusCodeError),
374 #[error("The header name isn't valid utf-8")]
375 HeaderNameInvalidUtf8,
376 #[error("The header name coouldn't be parsed")]
377 HeaderName(#[source] HeaderNameValidateError),
378 #[error("The header value isn't valid utf-8")]
379 HeaderValueInvalidUtf8,
380 #[error("The header value coouldn't be parsed")]
381 HeaderValue(#[source] HeaderValueValidateError),
382 #[error("INFO command JSON payload couldn't be deserialized")]
383 InvalidInfo(#[source] serde_json::Error),
384 #[error("-ERR command message couldn't be deserialized")]
385 InvalidErrorMessage,
386 #[error("The decoder was poisoned")]
387 Poisoned,
388}