servlin/
util.rs

1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
/// Outcome of copying bytes from a reader to a writer.
///
/// Unlike a plain `Result`, this keeps read failures and write failures in
/// separate variants so callers can handle each side differently.
#[derive(Debug)]
pub enum CopyResult {
    /// The copy finished; holds the byte count reported by the copy function.
    Ok(u64),
    /// Reading from the source failed with the contained error.
    ReaderErr(std::io::Error),
    /// Writing to the destination failed with the contained error.
    WriterErr(std::io::Error),
}
impl CopyResult {
    /// Converts this value into a `Result`, mapping each kind of failure with its own closure.
    ///
    /// Only the closure matching the variant is invoked; `Ok(n)` passes through untouched.
    ///
    /// # Errors
    /// Returns `Err(read_err_op(e))` for [`CopyResult::ReaderErr`] and
    /// `Err(write_err_op(e))` for [`CopyResult::WriterErr`].
    pub fn map_errs<E>(
        self,
        read_err_op: impl FnOnce(std::io::Error) -> E,
        write_err_op: impl FnOnce(std::io::Error) -> E,
    ) -> Result<u64, E> {
        match self {
            Self::Ok(n) => Ok(n),
            Self::ReaderErr(e) => Err(read_err_op(e)),
            Self::WriterErr(e) => Err(write_err_op(e)),
        }
    }
}
25
26/// Copies bytes from `reader` to `writer`.
27#[allow(clippy::missing_panics_doc)]
28pub async fn copy_async(
29    mut reader: impl AsyncRead + Unpin,
30    mut writer: impl AsyncWrite + Unpin,
31    expected_len: u64,
32) -> CopyResult {
33    let block_len = usize::try_from(expected_len)
34        .unwrap_or(usize::MAX)
35        .min(65536);
36    let mut buf: Vec<u8> = vec![0; block_len];
37    let mut num_copied = 0;
38    loop {
39        let num_read = match reader.read(buf.as_mut_slice()).await {
40            Ok(0) => return CopyResult::Ok(num_copied),
41            Ok(n) => n,
42            Err(e) => return CopyResult::ReaderErr(e),
43        };
44        match writer.write_all(&buf[..num_read]).await {
45            Ok(()) => num_copied += u64::try_from(num_read).unwrap(),
46            Err(e) => return CopyResult::WriterErr(e),
47        }
48    }
49}
50
/// Returns the lowercase ASCII hex digit for a value in `0..=15`.
///
/// Any larger value hits `unimplemented!()` and panics.
fn hex_digit(n: u8) -> u8 {
    match n {
        0..=9 => b'0' + n,
        10..=15 => b'a' + (n - 10),
        _ => unimplemented!(),
    }
}
72
/// Returns `slice` with every leading occurrence of `prefix` removed.
///
/// If the slice consists only of `prefix` bytes (or is empty), the result is empty.
fn trim_prefix(slice: &[u8], prefix: u8) -> &[u8] {
    // Index of the first byte that differs from `prefix`; the whole length if none does.
    let start = slice
        .iter()
        .position(|&byte| byte != prefix)
        .unwrap_or(slice.len());
    &slice[start..]
}
79
80/// Reads blocks from `reader`, encodes them in HTTP chunked encoding, and writes them to `writer`.
81#[allow(clippy::missing_panics_doc)]
82pub async fn copy_chunked_async(
83    mut reader: impl AsyncRead + Unpin,
84    mut writer: impl AsyncWrite + Unpin,
85) -> CopyResult {
86    let mut num_copied = 0;
87    loop {
88        #[allow(clippy::large_stack_arrays)]
89        let mut buf = Box::pin([0_u8; 65536]);
90        let len = match reader.read(&mut buf[6..65534]).await {
91            Ok(0) => break,
92            Ok(len) => len,
93            Err(e) => return CopyResult::ReaderErr(e),
94        };
95        buf[0] = hex_digit(u8::try_from((len >> 12) & 0xF).unwrap());
96        buf[1] = hex_digit(u8::try_from((len >> 8) & 0xF).unwrap());
97        buf[2] = hex_digit(u8::try_from((len >> 4) & 0xF).unwrap());
98        buf[3] = hex_digit(u8::try_from(len & 0xF).unwrap());
99        buf[4] = b'\r';
100        buf[5] = b'\n';
101        buf[6 + len] = b'\r';
102        buf[6 + len + 1] = b'\n';
103        let bytes = &buf[..(6 + len + 2)];
104        let bytes = trim_prefix(bytes, b'0');
105        if let Err(e) = writer.write_all(bytes).await {
106            return CopyResult::WriterErr(e);
107        }
108        num_copied += len as u64;
109    }
110    if let Err(e) = writer.write_all(b"0\r\n\r\n").await {
111        return CopyResult::WriterErr(e);
112    }
113    num_copied += 3;
114    CopyResult::Ok(num_copied)
115}
116
/// Renders a byte slice as a printable ASCII string.
///
/// Printable ASCII bytes are copied unchanged; every other byte is replaced by an
/// escape sequence such as "\n" or "\x19".
///
/// Each byte is escaped with
/// [`core::ascii::escape_default`](https://doc.rust-lang.org/core/ascii/fn.escape_default.html).
///
/// Handy for writing byte slices to logs and for comparing byte slices in tests.
///
/// Example test:
/// ```
/// use fixed_buffer::escape_ascii;
/// assert_eq!("abc", escape_ascii(b"abc"));
/// assert_eq!("abc\\n", escape_ascii(b"abc\n"));
/// assert_eq!(
///     "Euro sign: \\xe2\\x82\\xac",
///     escape_ascii("Euro sign: \u{20AC}".as_bytes())
/// );
/// assert_eq!("\\x01\\x02\\x03", escape_ascii(&[1, 2, 3]));
/// ```
// NOTE(review): the example above imports from `fixed_buffer`; confirm that this is the
// intended path for the function defined here.
#[must_use]
pub fn escape_ascii(input: &[u8]) -> String {
    input
        .iter()
        .flat_map(|&byte| core::ascii::escape_default(byte))
        // The escaper only ever yields ASCII bytes, so each one maps directly to a `char`.
        .map(char::from)
        .collect()
}
149
150#[must_use]
151#[allow(clippy::missing_panics_doc)]
152pub fn escape_and_elide(input: &[u8], max_len: usize) -> String {
153    if input.len() > max_len {
154        escape_ascii(&input[..max_len]) + "..."
155    } else {
156        escape_ascii(input)
157    }
158}
159
/// Returns the index of the first occurrence of `needle` inside `haystack`.
///
/// An empty `needle` matches at index 0. Returns `None` when `needle` does not occur,
/// including when it is longer than `haystack`.
pub fn find_slice<T: PartialEq>(needle: &[T], haystack: &[T]) -> Option<usize> {
    // `windows` rejects a size of zero, so the empty needle is answered up front.
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
170
171/// Wraps an `AsyncWrite` and records the number of bytes successfully written to it.
172pub struct AsyncWriteCounter<W>(W, u64);
173impl<W: AsyncWrite + Unpin> AsyncWriteCounter<W> {
174    pub fn new(writer: W) -> Self {
175        Self(writer, 0)
176    }
177
178    pub fn num_bytes_written(&self) -> u64 {
179        self.1
180    }
181}
182impl<W: AsyncWrite + Unpin> AsyncWrite for AsyncWriteCounter<W> {
183    fn poll_write(
184        mut self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186        buf: &[u8],
187    ) -> Poll<Result<usize, std::io::Error>> {
188        match Pin::new(&mut self.0).poll_write(cx, buf) {
189            Poll::Ready(Ok(n)) => {
190                self.1 += n as u64;
191                Poll::Ready(Ok(n))
192            }
193            other => other,
194        }
195    }
196
197    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
198        Pin::new(&mut self.0).poll_flush(cx)
199    }
200
201    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
202        Pin::new(&mut self.0).poll_close(cx)
203    }
204}