Skip to main content

xitca_http/h1/proto/
buf_write.rs

1use core::convert::Infallible;
2
3use std::io::Write;
4
5use tracing::warn;
6
7use crate::{
8    bytes::{BufMut, BufMutWriter, Bytes, BytesMut},
9    http::header::{self, HeaderMap, HeaderName},
10};
11
12/// trait for add http/1 data to buffer that implement [BufWrite] trait.
13pub trait H1BufWrite {
14    /// write http response head(status code and reason line, header lines) to buffer with fallible
15    /// closure. on error path the buffer is reverted back to state before method was called.
16    #[inline]
17    fn write_buf_head<F, T, E>(&mut self, func: F) -> Result<T, E>
18    where
19        F: FnOnce(&mut BytesMut) -> Result<T, E>,
20    {
21        self.write_buf(func)
22    }
23
24    /// write `&'static [u8]` to buffer.
25    fn write_buf_static(&mut self, bytes: &'static [u8]) {
26        let _ = self.write_buf(|buf| {
27            buf.put_slice(bytes);
28            Ok::<_, Infallible>(())
29        });
30    }
31
32    /// write bytes to buffer as is.
33    fn write_buf_bytes(&mut self, bytes: Bytes) {
34        let _ = self.write_buf(|buf| {
35            buf.put_slice(bytes.as_ref());
36            Ok::<_, Infallible>(())
37        });
38    }
39
40    /// write bytes to buffer as `transfer-encoding: chunked` encoded.
41    fn write_buf_bytes_chunked(&mut self, bytes: Bytes) {
42        let _ = self.write_buf(|buf| {
43            write!(BufMutWriter(buf), "{:X}\r\n", bytes.len()).unwrap();
44            buf.reserve(bytes.len() + 2);
45            buf.put_slice(bytes.as_ref());
46            buf.put_slice(b"\r\n");
47            Ok::<_, Infallible>(())
48        });
49    }
50
51    fn write_buf_trailers(&mut self, trailers: HeaderMap) {
52        let _ = self.write_buf(|buf| {
53            buf.put_slice(b"0\r\n");
54            for (name, value) in trailers.iter() {
55                if is_forbidden_trailer_field(name) {
56                    warn!(target: "h1_encode", "filtered forbidden trailer field: {}", name);
57                    continue;
58                }
59                buf.reserve(name.as_str().len() + 2 + value.len() + 2);
60                buf.put_slice(name.as_str().as_bytes());
61                buf.put_slice(b": ");
62                buf.put_slice(value.as_bytes());
63                buf.put_slice(b"\r\n");
64            }
65            buf.put_slice(b"\r\n");
66            Ok::<_, Infallible>(())
67        });
68    }
69
70    fn write_buf<F, T, E>(&mut self, func: F) -> Result<T, E>
71    where
72        F: FnOnce(&mut BytesMut) -> Result<T, E>;
73}
74
75impl H1BufWrite for BytesMut {
76    fn write_buf<F, T, E>(&mut self, func: F) -> Result<T, E>
77    where
78        F: FnOnce(&mut BytesMut) -> Result<T, E>,
79    {
80        let len = self.len();
81        func(self).inspect_err(|_| self.truncate(len))
82    }
83}
84
85// Returns true if the header name is forbidden in trailers per RFC 9110 ยง6.5.1.
86fn is_forbidden_trailer_field(name: &HeaderName) -> bool {
87    matches!(
88        *name,
89        header::AUTHORIZATION
90            | header::CACHE_CONTROL
91            | header::CONTENT_ENCODING
92            | header::CONTENT_LENGTH
93            | header::CONTENT_RANGE
94            | header::CONTENT_TYPE
95            | header::HOST
96            | header::MAX_FORWARDS
97            | header::SET_COOKIE
98            | header::TRAILER
99            | header::TRANSFER_ENCODING
100            | header::TE
101    )
102}