1use core::{fmt, mem};
2
3use std::io;
4
5use tracing::{trace, warn};
6
7use crate::bytes::{Buf, Bytes, BytesMut};
8
9use super::{buf_write::H1BufWrite, error::ProtoError};
10
/// Framing of an HTTP/1 message body. The same type drives both decoding (`decode`) and
/// encoding (`encode` / `encode_eof`) of body bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TransferCoding {
    /// No (more) body. Decoding in this state reports `ChunkResult::AlreadyEof`; encoding only logs
    /// a warning.
    Eof,
    /// The coder hit an unrecoverable error. Decoding in this state reports `ChunkResult::Corrupted`.
    Corrupted,
    /// Body delimited by a byte count. The value is the number of bytes still to be read (decode)
    /// or still allowed to be written (encode).
    Length(u64),
    /// Chunked body being decoded: the chunk parser state plus the chunk size being parsed /
    /// chunk-data bytes still to read.
    DecodeChunked(ChunkedState, u64),
    /// Body being encoded with chunked framing.
    EncodeChunked,
    /// Bytes passed through with no framing (decode yields all buffered input; encode writes bytes
    /// as-is). Presumably used after a protocol upgrade — inferred from the name.
    Upgrade,
}
27
28impl TransferCoding {
29 #[inline]
30 pub const fn eof() -> Self {
31 Self::Eof
32 }
33
34 #[inline]
35 pub const fn length(len: u64) -> Self {
36 Self::Length(len)
37 }
38
39 #[inline]
40 pub const fn decode_chunked() -> Self {
41 Self::DecodeChunked(ChunkedState::Size, 0)
42 }
43
44 #[inline]
45 pub const fn encode_chunked() -> Self {
46 Self::EncodeChunked
47 }
48
49 #[inline]
50 pub const fn upgrade() -> Self {
51 Self::Upgrade
52 }
53
54 #[inline]
57 pub fn is_eof(&self) -> bool {
58 match self {
59 Self::Eof => true,
60 Self::EncodeChunked => unreachable!("TransferCoding can't decide eof state when encoding chunked data"),
61 _ => false,
62 }
63 }
64
65 #[inline]
66 pub fn is_upgrade(&self) -> bool {
67 matches!(self, Self::Upgrade)
68 }
69}
70
71#[derive(Clone, Debug, Eq, PartialEq)]
72pub enum ChunkedState {
73 Size,
74 SizeLws,
75 Extension,
76 SizeLf,
77 Body,
78 BodyCr,
79 BodyLf,
80 Trailer,
81 TrailerLf,
82 EndCr,
83 EndLf,
84 End,
85}
86
87macro_rules! byte (
88 ($rdr:ident) => ({
89 if $rdr.len() > 0 {
90 let b = $rdr[0];
91 $rdr.advance(1);
92 b
93 } else {
94 return Ok(None);
95 }
96 })
97);
98
99impl ChunkedState {
100 pub fn step(&mut self, body: &mut BytesMut, size: &mut u64, buf: &mut Option<Bytes>) -> io::Result<Option<Self>> {
101 match *self {
102 Self::Size => Self::read_size(body, size),
103 Self::SizeLws => Self::read_size_lws(body),
104 Self::Extension => Self::read_extension(body),
105 Self::SizeLf => Self::read_size_lf(body, size),
106 Self::Body => Self::read_body(body, size, buf),
107 Self::BodyCr => Self::read_body_cr(body),
108 Self::BodyLf => Self::read_body_lf(body),
109 Self::Trailer => Self::read_trailer(body),
110 Self::TrailerLf => Self::read_trailer_lf(body),
111 Self::EndCr => Self::read_end_cr(body),
112 Self::EndLf => Self::read_end_lf(body),
113 Self::End => Ok(Some(Self::End)),
114 }
115 }
116
117 fn read_size(rdr: &mut BytesMut, size: &mut u64) -> io::Result<Option<Self>> {
118 macro_rules! or_overflow {
119 ($e:expr) => (
120 match $e {
121 Some(val) => val,
122 None => return Err(io::Error::new(
123 io::ErrorKind::InvalidData,
124 "invalid chunk size: overflow",
125 )),
126 }
127 )
128 }
129
130 let radix = 16;
131 match byte!(rdr) {
132 b @ b'0'..=b'9' => {
133 *size = or_overflow!(size.checked_mul(radix));
134 *size = or_overflow!(size.checked_add((b - b'0') as u64));
135 }
136 b @ b'a'..=b'f' => {
137 *size = or_overflow!(size.checked_mul(radix));
138 *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
139 }
140 b @ b'A'..=b'F' => {
141 *size = or_overflow!(size.checked_mul(radix));
142 *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
143 }
144 b'\t' | b' ' => return Ok(Some(ChunkedState::SizeLws)),
145 b';' => return Ok(Some(ChunkedState::Extension)),
146 b'\r' => return Ok(Some(ChunkedState::SizeLf)),
147 _ => {
148 return Err(io::Error::new(
149 io::ErrorKind::InvalidInput,
150 "Invalid chunk size line: Invalid Size",
151 ));
152 }
153 }
154
155 Ok(Some(ChunkedState::Size))
156 }
157
158 fn read_size_lws(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
159 match byte!(rdr) {
160 b'\t' | b' ' => Ok(Some(Self::SizeLws)),
162 b';' => Ok(Some(Self::Extension)),
163 b'\r' => Ok(Some(Self::SizeLf)),
164 _ => Err(io::Error::new(
165 io::ErrorKind::InvalidInput,
166 "Invalid chunk size linear white space",
167 )),
168 }
169 }
170
171 fn read_extension(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
172 match byte!(rdr) {
173 b'\r' => Ok(Some(Self::SizeLf)),
174 b'\n' => Err(io::Error::new(
175 io::ErrorKind::InvalidData,
176 "invalid chunk extension contains newline",
177 )),
178 _ => Ok(Some(Self::Extension)), }
180 }
181
182 fn read_size_lf(rdr: &mut BytesMut, size: &u64) -> io::Result<Option<Self>> {
183 match byte!(rdr) {
184 b'\n' if *size > 0 => Ok(Some(Self::Body)),
185 b'\n' if *size == 0 => Ok(Some(Self::EndCr)),
186 _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")),
187 }
188 }
189
190 fn read_body(rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option<Bytes>) -> io::Result<Option<Self>> {
191 if rdr.is_empty() {
192 Ok(None)
193 } else {
194 *buf = Some(bounded_split(rem, rdr));
195 if *rem > 0 {
196 Ok(Some(Self::Body))
197 } else {
198 Ok(Some(Self::BodyCr))
199 }
200 }
201 }
202
203 fn read_body_cr(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
204 match byte!(rdr) {
205 b'\r' => Ok(Some(Self::BodyLf)),
206 _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")),
207 }
208 }
209
210 fn read_body_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
211 match byte!(rdr) {
212 b'\n' => Ok(Some(Self::Size)),
213 _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")),
214 }
215 }
216
217 fn read_trailer(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
218 trace!(target: "h1_decode", "read_trailer");
219 match byte!(rdr) {
220 b'\r' => Ok(Some(Self::TrailerLf)),
221 _ => Ok(Some(Self::Trailer)),
222 }
223 }
224
225 fn read_trailer_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
226 match byte!(rdr) {
227 b'\n' => Ok(Some(Self::EndCr)),
228 _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid trailer end LF")),
229 }
230 }
231
232 fn read_end_cr(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
233 match byte!(rdr) {
234 b'\r' => Ok(Some(Self::EndLf)),
235 _ => Ok(Some(Self::Trailer)),
236 }
237 }
238
239 fn read_end_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
240 match byte!(rdr) {
241 b'\n' => Ok(Some(Self::End)),
242 _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end LF")),
243 }
244 }
245}
246
247impl TransferCoding {
248 pub fn try_set(&mut self, other: Self) -> Result<(), ProtoError> {
249 match (&self, &other) {
250 (TransferCoding::Upgrade, TransferCoding::Upgrade) | (_, TransferCoding::Length(0)) => Ok(()),
254 (TransferCoding::Upgrade, _) | (TransferCoding::DecodeChunked(..), _) | (TransferCoding::Length(..), _) => {
257 Err(ProtoError::HeaderName)
258 }
259 _ => {
260 *self = other;
261 Ok(())
262 }
263 }
264 }
265
266 #[inline]
267 pub fn set_eof(&mut self) {
268 *self = Self::Eof;
269 }
270
271 #[inline]
272 pub fn set_corrupted(&mut self) {
273 *self = Self::Corrupted;
274 }
275
276 pub fn encode<W>(&mut self, mut bytes: Bytes, buf: &mut W)
278 where
279 W: H1BufWrite,
280 {
281 if bytes.is_empty() {
285 return;
286 }
287
288 match *self {
289 Self::Upgrade => buf.write_buf_bytes(bytes),
290 Self::EncodeChunked => buf.write_buf_bytes_chunked(bytes),
291 Self::Length(ref mut rem) => {
292 let len = bytes.len() as u64;
293 if *rem >= len {
294 buf.write_buf_bytes(bytes);
295 *rem -= len;
296 } else {
297 let rem = mem::replace(rem, 0u64);
298 buf.write_buf_bytes(bytes.split_to(rem as usize));
299 }
300 }
301 Self::Eof => warn!(target: "h1_encode", "TransferCoding::Eof should not encode response body"),
302 _ => unreachable!(),
303 }
304 }
305
306 pub fn encode_eof<W>(&mut self, buf: &mut W)
308 where
309 W: H1BufWrite,
310 {
311 match *self {
312 Self::Eof | Self::Upgrade | Self::Length(0) => {}
313 Self::EncodeChunked => buf.write_buf_static(b"0\r\n\r\n"),
314 Self::Length(n) => unreachable!("UnexpectedEof for Length Body with {} remaining", n),
315 _ => unreachable!(),
316 }
317 }
318
319 pub fn decode(&mut self, src: &mut BytesMut) -> ChunkResult {
321 match *self {
322 Self::Length(0) | Self::DecodeChunked(ChunkedState::End, _) => {
327 *self = Self::Eof;
328 ChunkResult::OnEof
329 }
330 Self::Eof => ChunkResult::AlreadyEof,
331 Self::Corrupted => ChunkResult::Corrupted,
332 ref _this if src.is_empty() => ChunkResult::InsufficientData,
333 Self::Length(ref mut rem) => ChunkResult::Ok(bounded_split(rem, src)),
334 Self::Upgrade => ChunkResult::Ok(src.split().freeze()),
335 Self::DecodeChunked(ref mut state, ref mut size) => {
336 loop {
337 let mut buf = None;
338 *state = match state.step(src, size, &mut buf) {
340 Ok(Some(state)) => state,
341 Ok(None) => return ChunkResult::InsufficientData,
342 Err(e) => return ChunkResult::Err(e),
343 };
344
345 if matches!(state, ChunkedState::End) {
346 return self.decode(src);
347 }
348
349 if let Some(buf) = buf {
350 return ChunkResult::Ok(buf);
351 }
352 }
353 }
354 _ => unreachable!(),
355 }
356 }
357}
358
/// Outcome of a single `TransferCoding::decode` call.
#[derive(Debug)]
pub enum ChunkResult {
    /// A piece of decoded body data.
    Ok(Bytes),
    /// Decoding failed, e.g. invalid chunked framing reported by the chunk parser.
    Err(io::Error),
    /// More input bytes are required before anything can be produced.
    InsufficientData,
    /// The body ended during this call; the coder has just been moved to the `Eof` state.
    OnEof,
    /// The coder was already in the `Eof` state before this call.
    AlreadyEof,
    /// The coder is corrupted and can not be used anymore.
    Corrupted,
}
375
376impl fmt::Display for ChunkResult {
377 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378 match *self {
379 Self::Ok(_) => f.write_str("chunked data."),
380 Self::Err(ref e) => fmt::Display::fmt(e, f),
381 Self::InsufficientData => f.write_str("no sufficient data. More input bytes required."),
382 Self::OnEof => f.write_str("coder reached EOF state. no more chunk can be produced."),
383 Self::AlreadyEof => f.write_str("coder already reached EOF state. no more chunk can be produced."),
384 Self::Corrupted => f.write_str("coder corrupted. can not be used anymore."),
385 }
386 }
387}
388
389impl From<io::Error> for ChunkResult {
390 fn from(e: io::Error) -> Self {
391 Self::Err(e)
392 }
393}
394
395fn bounded_split(rem: &mut u64, buf: &mut BytesMut) -> Bytes {
396 let len = buf.len() as u64;
397 if *rem >= len {
398 *rem -= len;
399 buf.split().freeze()
400 } else {
401 let rem = mem::replace(rem, 0);
402 buf.split_to(rem as usize).freeze()
403 }
404}
405
// Unit tests for chunk-size parsing, chunked decoding and chunked/length encoding.
#[cfg(test)]
mod test {
    use crate::util::buffered::WriteBuf;

    use super::*;

    #[test]
    fn test_read_chunk_size() {
        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};

        // Steps the parser from `ChunkedState::Size` until the size line is fully consumed
        // (`Body` for a non-zero size, `EndCr` for zero) and returns the parsed size.
        fn read(s: &str) -> u64 {
            let mut state = ChunkedState::Size;
            let rdr = &mut BytesMut::from(s);
            let mut size = 0;
            loop {
                let result = state.step(rdr, &mut size, &mut None);
                state = result.unwrap_or_else(|_| panic!("read_size failed for {s:?}")).unwrap();
                if state == ChunkedState::Body || state == ChunkedState::EndCr {
                    break;
                }
            }
            size
        }

        // Steps the parser and asserts it fails with `expected_err`. Running out of input
        // (`Ok(None)`) counts as `UnexpectedEof`.
        fn read_err(s: &str, expected_err: io::ErrorKind) {
            let mut state = ChunkedState::Size;
            let rdr = &mut BytesMut::from(s);
            let mut size = 0;
            loop {
                let result = state.step(rdr, &mut size, &mut None);
                state = match result {
                    Ok(Some(s)) => s,
                    Ok(None) => return assert_eq!(expected_err, UnexpectedEof),
                    Err(e) => {
                        assert_eq!(
                            expected_err,
                            e.kind(),
                            "Reading {:?}, expected {:?}, but got {:?}",
                            s,
                            expected_err,
                            e.kind()
                        );
                        return;
                    }
                };
                if state == ChunkedState::Body || state == ChunkedState::End {
                    panic!("Was Ok. Expected Err for {s:?}");
                }
            }
        }

        assert_eq!(1, read("1\r\n"));
        assert_eq!(1, read("01\r\n"));
        assert_eq!(0, read("0\r\n"));
        assert_eq!(0, read("00\r\n"));
        assert_eq!(10, read("A\r\n"));
        assert_eq!(10, read("a\r\n"));
        assert_eq!(255, read("Ff\r\n"));
        assert_eq!(255, read("Ff \r\n"));
        // CR not followed by LF, and input ending mid-line.
        read_err("F\rF", InvalidInput);
        read_err("F", UnexpectedEof);
        // Non-hex bytes in the size.
        read_err("X\r\n", InvalidInput);
        read_err("1X\r\n", InvalidInput);
        read_err("-\r\n", InvalidInput);
        read_err("-1\r\n", InvalidInput);
        // Chunk extensions are skipped.
        assert_eq!(1, read("1;extension\r\n"));
        assert_eq!(10, read("a;ext name=value\r\n"));
        assert_eq!(1, read("1;extension;extension2\r\n"));
        assert_eq!(1, read("1;;; ;\r\n"));
        assert_eq!(2, read("2; extension...\r\n"));
        assert_eq!(3, read("3 ; extension=123\r\n"));
        assert_eq!(3, read("3 ;\r\n"));
        assert_eq!(3, read("3 ; \r\n"));
        // Anything after whitespace other than `;` or CR is rejected.
        read_err("1 invalid extension\r\n", InvalidInput);
        read_err("1 A\r\n", InvalidInput);
        read_err("1;no CRLF", UnexpectedEof);
        read_err("1;reject\nnewlines\r\n", InvalidData);
        // A size that overflows u64.
        read_err("f0000000000000003\r\n", InvalidData);
    }

    #[test]
    fn test_read_chunked_single_read() {
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n");

        match TransferCoding::decode_chunked().decode(mock_buf) {
            ChunkResult::Ok(buf) => {
                assert_eq!(16, buf.len());
                let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
                assert_eq!("1234567890abcdef", &result);
            }
            state => panic!("{}", state),
        }
    }

    #[test]
    fn test_read_chunked_trailer_with_missing_lf() {
        // The trailer line "bad" ends with CR CR instead of CRLF.
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n");

        let mut decoder = TransferCoding::decode_chunked();

        match decoder.decode(mock_buf) {
            ChunkResult::Ok(_) => {}
            state => panic!("{}", state),
        }

        match decoder.decode(mock_buf) {
            ChunkResult::Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidInput),
            state => panic!("{}", state),
        }
    }

    #[test]
    fn test_read_chunked_after_eof() {
        let mock_buf = &mut BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n\r\n");
        let mut decoder = TransferCoding::decode_chunked();

        // First call yields the 16-byte chunk.
        match decoder.decode(mock_buf) {
            ChunkResult::Ok(buf) => {
                assert_eq!(16, buf.len());
                let result = String::from_utf8(buf.as_ref().to_vec()).unwrap();
                assert_eq!("1234567890abcdef", &result);
            }
            state => panic!("{}", state),
        }

        // Second call consumes the terminating chunk and reports the transition to EOF.
        match decoder.decode(mock_buf) {
            ChunkResult::OnEof => {}
            state => panic!("{}", state),
        }

        // Further calls report that EOF was already reached.
        match decoder.decode(mock_buf) {
            ChunkResult::AlreadyEof => {}
            state => panic!("{}", state),
        }
    }

    #[test]
    fn encode_chunked() {
        let mut encoder = TransferCoding::encode_chunked();
        let dst = &mut WriteBuf::<1024>::default();

        let msg1 = Bytes::from("foo bar");
        encoder.encode(msg1, dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\n");

        let msg2 = Bytes::from("baz quux herp");
        encoder.encode(msg2, dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        encoder.encode_eof(dst);

        assert_eq!(dst.buf(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n");
    }

    #[test]
    fn encode_length() {
        let max_len = 8;
        let mut encoder = TransferCoding::length(max_len as u64);

        let dst = &mut WriteBuf::<1024>::default();

        let msg1 = Bytes::from("foo bar");
        encoder.encode(msg1, dst);

        assert_eq!(dst.buf(), b"foo bar");

        // Bytes beyond the declared length are dropped: only one "b" fits.
        for _ in 0..8 {
            let msg2 = Bytes::from("baz");
            encoder.encode(msg2, dst);

            assert_eq!(dst.buf().len(), max_len);
            assert_eq!(dst.buf(), b"foo barb");
        }

        encoder.encode_eof(dst);
        assert_eq!(dst.buf().len(), max_len);
        assert_eq!(dst.buf(), b"foo barb");
    }
}