1use std::{borrow::Borrow, marker::PhantomData, ops::Deref, str::Utf8Error};
4
5use bytes::{Bytes, BytesMut};
6
7#[cfg(feature = "tokio-codec")]
8use tokio_util::codec::{Decoder, Encoder};
9
10use crate::{
11 error::Error,
12 header::{
13 FieldIter, HeaderField, HeaderFieldDecoder, HeaderFieldEncoder, HeaderFieldValue,
14 HeaderFields, Iter,
15 },
16 line::{LineDecoder, LineDecoderOptions},
17 utils::ascii::AsciiExt,
18};
19
20#[cfg(feature = "tokio-codec")]
21use crate::error::CodecError;
22
23#[derive(Debug, Clone)]
25pub struct RequestPath {
26 inner: Bytes,
27}
28
29impl RequestPath {
30 #[inline]
32 pub const fn from_static_str(s: &'static str) -> Self {
33 Self::from_static_bytes(s.as_bytes())
34 }
35
36 #[inline]
38 pub const fn from_static_bytes(s: &'static [u8]) -> Self {
39 Self {
40 inner: Bytes::from_static(s),
41 }
42 }
43
44 #[inline]
46 pub fn to_str(&self) -> Result<&str, Utf8Error> {
47 std::str::from_utf8(&self.inner)
48 }
49}
50
51impl PartialEq for RequestPath {
52 #[inline]
53 fn eq(&self, other: &Self) -> bool {
54 self.inner.eq(&other.inner)
55 }
56}
57
58impl Eq for RequestPath {}
59
60impl AsRef<[u8]> for RequestPath {
61 #[inline]
62 fn as_ref(&self) -> &[u8] {
63 &self.inner
64 }
65}
66
67impl Borrow<[u8]> for RequestPath {
68 #[inline]
69 fn borrow(&self) -> &[u8] {
70 &self.inner
71 }
72}
73
74impl Deref for RequestPath {
75 type Target = [u8];
76
77 #[inline]
78 fn deref(&self) -> &Self::Target {
79 &self.inner
80 }
81}
82
83impl From<&'static [u8]> for RequestPath {
84 #[inline]
85 fn from(s: &'static [u8]) -> Self {
86 Self::from(Bytes::from(s))
87 }
88}
89
90impl From<&'static str> for RequestPath {
91 #[inline]
92 fn from(s: &'static str) -> Self {
93 Self::from(Bytes::from(s))
94 }
95}
96
97impl From<Bytes> for RequestPath {
98 #[inline]
99 fn from(bytes: Bytes) -> Self {
100 Self { inner: bytes }
101 }
102}
103
104impl From<BytesMut> for RequestPath {
105 #[inline]
106 fn from(bytes: BytesMut) -> Self {
107 Self::from(Bytes::from(bytes))
108 }
109}
110
111impl From<Box<[u8]>> for RequestPath {
112 #[inline]
113 fn from(bytes: Box<[u8]>) -> Self {
114 Self::from(Bytes::from(bytes))
115 }
116}
117
118impl From<Vec<u8>> for RequestPath {
119 #[inline]
120 fn from(bytes: Vec<u8>) -> Self {
121 Self::from(Bytes::from(bytes))
122 }
123}
124
125impl From<String> for RequestPath {
126 #[inline]
127 fn from(s: String) -> Self {
128 Self::from(Bytes::from(s))
129 }
130}
131
132struct InvalidRequestLine;
134
135impl From<InvalidRequestLine> for Error {
136 fn from(_: InvalidRequestLine) -> Self {
137 Error::from_static_msg("invalid request line")
138 }
139}
140
141#[derive(Clone)]
143pub struct RequestHeaderBuilder<P = Bytes, V = Bytes, M = Bytes> {
144 header: RequestHeader<P, V, M>,
145}
146
147impl<P, V, M> RequestHeaderBuilder<P, V, M> {
148 #[inline]
150 pub fn set_version(mut self, version: V) -> Self {
151 self.header.version = version;
152 self
153 }
154
155 #[inline]
157 pub fn set_method(mut self, method: M) -> Self {
158 self.header.method = method;
159 self
160 }
161
162 #[inline]
164 pub fn set_path(mut self, path: RequestPath) -> Self {
165 self.header.path = path;
166 self
167 }
168
169 pub fn set_header_field<T>(mut self, field: T) -> Self
171 where
172 T: Into<HeaderField>,
173 {
174 self.header.header_fields.set(field);
175 self
176 }
177
178 pub fn add_header_field<T>(mut self, field: T) -> Self
180 where
181 T: Into<HeaderField>,
182 {
183 self.header.header_fields.add(field);
184 self
185 }
186
187 pub fn remove_header_fields<N>(mut self, name: &N) -> Self
189 where
190 N: AsRef<[u8]> + ?Sized,
191 {
192 self.header.header_fields.remove(name);
193 self
194 }
195
196 #[inline]
198 pub fn build(self) -> RequestHeader<P, V, M> {
199 self.header
200 }
201}
202
203impl<P, V, M> From<RequestHeader<P, V, M>> for RequestHeaderBuilder<P, V, M> {
204 #[inline]
205 fn from(header: RequestHeader<P, V, M>) -> Self {
206 Self { header }
207 }
208}
209
210#[derive(Debug, Clone)]
215pub struct RequestHeader<P = Bytes, V = Bytes, M = Bytes> {
216 method: M,
217 path: RequestPath,
218 protocol: P,
219 version: V,
220 header_fields: HeaderFields,
221}
222
223impl RequestHeader {
224 fn parse_request_line(line: Bytes) -> Result<Self, InvalidRequestLine> {
226 let (method, rest) = line
227 .trim_ascii_start()
228 .split_once(|b| b.is_ascii_whitespace())
229 .ok_or(InvalidRequestLine)?;
230
231 let (path, rest) = rest
232 .trim_ascii_start()
233 .split_once(|b| b.is_ascii_whitespace())
234 .ok_or(InvalidRequestLine)?;
235
236 let (protocol, version) = rest.split_once(|b| b == b'/').ok_or(InvalidRequestLine)?;
237
238 let res = Self {
239 method,
240 path: path.into(),
241 protocol: protocol.trim_ascii(),
242 version: version.trim_ascii(),
243 header_fields: HeaderFields::new(),
244 };
245
246 Ok(res)
247 }
248
249 fn parse_request_parts<P, V, M>(self) -> Result<RequestHeader<P, V, M>, Error>
251 where
252 P: TryFrom<Bytes>,
253 V: TryFrom<Bytes>,
254 M: TryFrom<Bytes>,
255 Error: From<P::Error>,
256 Error: From<V::Error>,
257 Error: From<M::Error>,
258 {
259 let protocol = P::try_from(self.protocol)?;
260 let version = V::try_from(self.version)?;
261 let method = M::try_from(self.method)?;
262
263 let res = RequestHeader {
264 method,
265 path: self.path,
266 protocol,
267 version,
268 header_fields: self.header_fields,
269 };
270
271 Ok(res)
272 }
273}
274
275impl<P, V, M> RequestHeader<P, V, M> {
276 #[inline]
285 pub const fn new(protocol: P, version: V, method: M, path: RequestPath) -> Self {
286 Self {
287 method,
288 path,
289 protocol,
290 version,
291 header_fields: HeaderFields::new(),
292 }
293 }
294
295 #[inline]
304 pub const fn builder(
305 protocol: P,
306 version: V,
307 method: M,
308 path: RequestPath,
309 ) -> RequestHeaderBuilder<P, V, M> {
310 RequestHeaderBuilder {
311 header: Self::new(protocol, version, method, path),
312 }
313 }
314
315 #[inline]
317 pub fn method(&self) -> &M {
318 &self.method
319 }
320
321 #[inline]
323 pub fn protocol(&self) -> &P {
324 &self.protocol
325 }
326
327 #[inline]
329 pub fn version(&self) -> &V {
330 &self.version
331 }
332
333 #[inline]
335 pub fn path(&self) -> &RequestPath {
336 &self.path
337 }
338
339 #[inline]
341 pub fn get_all_header_fields(&self) -> Iter<'_> {
342 self.header_fields.all()
343 }
344
345 pub fn get_header_fields<'a, N>(&'a self, name: &'a N) -> FieldIter<'a>
347 where
348 N: AsRef<[u8]> + ?Sized,
349 {
350 self.header_fields.get(name)
351 }
352
353 pub fn get_header_field<'a, N>(&'a self, name: &'a N) -> Option<&'a HeaderField>
355 where
356 N: AsRef<[u8]> + ?Sized,
357 {
358 self.header_fields.last(name)
359 }
360
361 pub fn get_header_field_value<'a, N>(&'a self, name: &'a N) -> Option<&'a HeaderFieldValue>
363 where
364 N: AsRef<[u8]> + ?Sized,
365 {
366 self.header_fields.last_value(name)
367 }
368}
369
/// Request header encoder.
///
/// The encoder carries no state. The private unit field only prevents
/// construction outside of this module, so that state can be added later
/// without breaking the public API.
///
/// `Debug` and `Clone` are derived so that this public type can be embedded
/// in types that derive them.
#[derive(Debug, Clone)]
pub struct RequestHeaderEncoder(());
372
373impl RequestHeaderEncoder {
374 #[inline]
376 pub const fn new() -> Self {
377 Self(())
378 }
379
380 pub fn encode<P, V, M>(&mut self, header: &RequestHeader<P, V, M>, dst: &mut BytesMut)
382 where
383 P: AsRef<[u8]>,
384 V: AsRef<[u8]>,
385 M: AsRef<[u8]>,
386 {
387 fn inner(
389 method: &[u8],
390 path: &[u8],
391 protocol: &[u8],
392 version: &[u8],
393 fields: &HeaderFields,
394 dst: &mut BytesMut,
395 ) {
396 let mut hfe = HeaderFieldEncoder::new();
397
398 let len = 7
399 + method.len()
400 + path.len()
401 + protocol.len()
402 + version.len()
403 + fields
404 .all()
405 .map(|f| 2 + hfe.get_encoded_length(f))
406 .sum::<usize>();
407
408 dst.reserve(len);
409
410 dst.extend_from_slice(method);
411 dst.extend_from_slice(b" ");
412
413 dst.extend_from_slice(path);
414 dst.extend_from_slice(b" ");
415
416 dst.extend_from_slice(protocol);
417 dst.extend_from_slice(b"/");
418 dst.extend_from_slice(version);
419 dst.extend_from_slice(b"\r\n");
420
421 for field in fields.all() {
422 hfe.encode(field, dst);
423 dst.extend_from_slice(b"\r\n");
424 }
425
426 dst.extend_from_slice(b"\r\n");
427 }
428
429 let method = header.method.as_ref();
430 let path = header.path.as_ref();
431 let protocol = header.protocol.as_ref();
432 let version = header.version.as_ref();
433
434 inner(method, path, protocol, version, &header.header_fields, dst)
435 }
436}
437
438impl Default for RequestHeaderEncoder {
439 #[inline]
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
/// Tokio codec adapter for the inherent `RequestHeaderEncoder::encode`
/// method. The inherent method has no error path, so this implementation
/// always returns `Ok(())`.
#[cfg(feature = "tokio-codec")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-codec")))]
impl<P, V, M> Encoder<&RequestHeader<P, V, M>> for RequestHeaderEncoder
where
    P: AsRef<[u8]>,
    V: AsRef<[u8]>,
    M: AsRef<[u8]>,
{
    type Error = CodecError;

    #[inline]
    fn encode(
        &mut self,
        header: &RequestHeader<P, V, M>,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        // Call the inherent method explicitly to avoid recursing into this
        // trait method.
        let encoder: &mut RequestHeaderEncoder = self;

        RequestHeaderEncoder::encode(encoder, header, dst);

        Ok(())
    }
}
466
467#[derive(Copy, Clone)]
469pub struct RequestHeaderDecoderOptions {
470 line_decoder_options: LineDecoderOptions,
471 max_header_field_length: Option<usize>,
472 max_header_fields: Option<usize>,
473}
474
475impl RequestHeaderDecoderOptions {
476 #[inline]
482 pub const fn new() -> Self {
483 let line_decoder_options = LineDecoderOptions::new()
484 .cr(false)
485 .lf(false)
486 .crlf(true)
487 .max_line_length(Some(4096))
488 .require_terminator(false);
489
490 Self {
491 line_decoder_options,
492 max_header_field_length: Some(4096),
493 max_header_fields: Some(64),
494 }
495 }
496
497 #[inline]
499 pub const fn accept_all_line_endings(mut self, enabled: bool) -> Self {
500 self.line_decoder_options = self.line_decoder_options.cr(enabled).lf(enabled).crlf(true);
501
502 self
503 }
504
505 #[inline]
507 pub const fn max_line_length(mut self, max_length: Option<usize>) -> Self {
508 self.line_decoder_options = self.line_decoder_options.max_line_length(max_length);
509 self
510 }
511
512 #[inline]
514 pub const fn max_header_field_length(mut self, max_length: Option<usize>) -> Self {
515 self.max_header_field_length = max_length;
516 self
517 }
518
519 #[inline]
521 pub const fn max_header_fields(mut self, max_fields: Option<usize>) -> Self {
522 self.max_header_fields = max_fields;
523 self
524 }
525}
526
527impl Default for RequestHeaderDecoderOptions {
528 #[inline]
529 fn default() -> Self {
530 Self::new()
531 }
532}
533
534pub struct RequestHeaderDecoder<P, V, M> {
536 inner: InternalRequestHeaderDecoder,
537 _pd: PhantomData<(P, V, M)>,
538}
539
540impl<P, V, M> RequestHeaderDecoder<P, V, M> {
541 pub fn new(options: RequestHeaderDecoderOptions) -> Self {
543 Self {
544 inner: InternalRequestHeaderDecoder::new(options),
545 _pd: PhantomData,
546 }
547 }
548
549 pub fn reset(&mut self) {
551 self.inner.reset();
552 }
553}
554
555impl<P, V, M> RequestHeaderDecoder<P, V, M>
556where
557 P: TryFrom<Bytes>,
558 V: TryFrom<Bytes>,
559 M: TryFrom<Bytes>,
560 Error: From<P::Error>,
561 Error: From<V::Error>,
562 Error: From<M::Error>,
563{
564 pub fn decode(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader<P, V, M>>, Error> {
566 let res = self
567 .inner
568 .decode(data)?
569 .map(RequestHeader::parse_request_parts)
570 .transpose()?;
571
572 Ok(res)
573 }
574
575 pub fn decode_eof(
577 &mut self,
578 data: &mut BytesMut,
579 ) -> Result<Option<RequestHeader<P, V, M>>, Error> {
580 let res = self
581 .inner
582 .decode_eof(data)?
583 .map(RequestHeader::parse_request_parts)
584 .transpose()?;
585
586 Ok(res)
587 }
588}
589
/// Tokio codec adapter delegating to the inherent `decode` and `decode_eof`
/// methods. Errors are wrapped in `CodecError::Other`.
#[cfg(feature = "tokio-codec")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-codec")))]
impl<P, V, M> Decoder for RequestHeaderDecoder<P, V, M>
where
    P: TryFrom<Bytes>,
    V: TryFrom<Bytes>,
    M: TryFrom<Bytes>,
    Error: From<P::Error>,
    Error: From<V::Error>,
    Error: From<M::Error>,
{
    type Item = RequestHeader<P, V, M>;
    type Error = CodecError;

    #[inline]
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Call the inherent method explicitly to avoid recursing into this
        // trait method.
        let decoded = RequestHeaderDecoder::<P, V, M>::decode(self, buf);

        decoded.map_err(CodecError::Other)
    }

    #[inline]
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Call the inherent method explicitly to avoid recursing into this
        // trait method.
        let decoded = RequestHeaderDecoder::<P, V, M>::decode_eof(self, buf);

        decoded.map_err(CodecError::Other)
    }
}
614
615struct InternalRequestHeaderDecoder {
617 line_decoder: LineDecoder,
618 header: Option<RequestHeader>,
619 field_decoder: HeaderFieldDecoder,
620 max_header_fields: Option<usize>,
621}
622
623impl InternalRequestHeaderDecoder {
624 fn new(options: RequestHeaderDecoderOptions) -> Self {
626 Self {
627 line_decoder: LineDecoder::new(options.line_decoder_options),
628 header: None,
629 field_decoder: HeaderFieldDecoder::new(options.max_header_field_length),
630 max_header_fields: options.max_header_fields,
631 }
632 }
633
634 fn decode(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader>, Error> {
636 while let Some(line) = self.line_decoder.decode(data)? {
637 if let Some(header) = self.decode_line(line)? {
638 return Ok(Some(header));
639 }
640 }
641
642 Ok(None)
643 }
644
645 fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader>, Error> {
647 while let Some(line) = self.line_decoder.decode_eof(data)? {
648 if let Some(header) = self.decode_line(line)? {
649 return Ok(Some(header));
650 }
651 }
652
653 if data.is_empty() && self.line_decoder.is_empty() && self.header.is_none() {
654 Ok(None)
655 } else {
656 Err(Error::from_static_msg("incomplete request header"))
657 }
658 }
659
660 fn decode_line(&mut self, line: Bytes) -> Result<Option<RequestHeader>, Error> {
662 if let Some(header) = self.header.as_mut() {
663 let is_empty_line = line.is_empty();
664
665 if let Some(field) = self.field_decoder.decode(line)? {
666 if let Some(max_fields) = self.max_header_fields {
667 if header.header_fields.len() >= max_fields {
668 return Err(Error::from_static_msg(
669 "maximum number of header fields exceeded",
670 ));
671 }
672 }
673
674 header.header_fields.add(field);
675 }
676
677 if is_empty_line {
679 return Ok(self.take());
680 }
681 } else {
682 self.header = Some(RequestHeader::parse_request_line(line)?);
683 }
684
685 Ok(None)
686 }
687
688 fn reset(&mut self) {
690 self.take();
691 }
692
693 fn take(&mut self) -> Option<RequestHeader> {
695 self.line_decoder.reset();
696 self.field_decoder.reset();
697
698 self.header.take()
699 }
700}