servlin/
util.rs

1use fixed_buffer::FixedBuf;
2use futures_io::{AsyncRead, AsyncWrite};
3use futures_lite::{AsyncReadExt, AsyncWriteExt};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
/// Outcome of copying bytes from a reader to a writer.
///
/// Reader failures and writer failures are kept apart so callers can report them differently
/// (for example, a client disconnect versus a failing body source).
#[derive(Debug)]
pub enum CopyResult {
    /// The copy finished; holds the byte count reported by the function that produced it.
    Ok(u64),
    /// Reading from the source failed.
    ReaderErr(std::io::Error),
    /// Writing to the destination failed.
    WriterErr(std::io::Error),
}

impl CopyResult {
    /// Converts this into a `Result`, mapping reader and writer errors with separate closures.
    ///
    /// Only the closure matching the variant is called.
    ///
    /// # Errors
    /// Returns `Err(read_err_op(e))` for [`CopyResult::ReaderErr`] and `Err(write_err_op(e))`
    /// for [`CopyResult::WriterErr`].
    pub fn map_errs<E>(
        self,
        read_err_op: impl FnOnce(std::io::Error) -> E,
        write_err_op: impl FnOnce(std::io::Error) -> E,
    ) -> Result<u64, E> {
        match self {
            CopyResult::Ok(n) => Ok(n),
            CopyResult::ReaderErr(e) => Err(read_err_op(e)),
            CopyResult::WriterErr(e) => Err(write_err_op(e)),
        }
    }
}
26
27/// Copies bytes from `reader` to `writer`.
28pub async fn copy_async(
29    mut reader: impl AsyncRead + Unpin,
30    mut writer: impl AsyncWrite + Unpin,
31) -> CopyResult {
32    let mut buf = Box::pin(<FixedBuf<65536>>::new());
33    let mut num_copied = 0;
34    loop {
35        match reader.read(buf.writable()).await {
36            Ok(0) => return CopyResult::Ok(num_copied),
37            Ok(n) => buf.wrote(n),
38            Err(e) => return CopyResult::ReaderErr(e),
39        }
40        let readable = buf.read_all();
41        match writer.write_all(readable).await {
42            Ok(()) => num_copied += readable.len() as u64,
43            Err(e) => return CopyResult::WriterErr(e),
44        }
45    }
46}
47
/// Returns the lowercase ASCII hexadecimal digit (`b'0'..=b'9'`, `b'a'..=b'f'`) for `n`.
///
/// # Panics
/// Panics if `n` is greater than 15; callers are expected to pass a single masked nibble.
fn hex_digit(n: u8) -> u8 {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    match DIGITS.get(usize::from(n)) {
        Some(digit) => *digit,
        None => panic!("hex_digit: {} is not a single hex digit", n),
    }
}
69
/// Returns `slice` with every leading occurrence of the byte `prefix` removed.
///
/// If the whole slice consists of `prefix` bytes, the result is empty.
fn trim_prefix(slice: &[u8], prefix: u8) -> &[u8] {
    let leading = slice.iter().take_while(|&&b| b == prefix).count();
    &slice[leading..]
}
76
77/// Reads blocks from `reader`, encodes them in HTTP chunked encoding, and writes them to `writer`.
78#[allow(clippy::missing_panics_doc)]
79pub async fn copy_chunked_async(
80    mut reader: impl AsyncRead + Unpin,
81    mut writer: impl AsyncWrite + Unpin,
82) -> CopyResult {
83    let mut num_copied = 0;
84    loop {
85        let mut buf = Box::pin([0_u8; 65536]);
86        let len = match reader.read(&mut buf[6..65534]).await {
87            Ok(0) => break,
88            Ok(len) => len,
89            Err(e) => return CopyResult::ReaderErr(e),
90        };
91        buf[0] = hex_digit(u8::try_from((len >> 12) & 0xF).unwrap());
92        buf[1] = hex_digit(u8::try_from((len >> 8) & 0xF).unwrap());
93        buf[2] = hex_digit(u8::try_from((len >> 4) & 0xF).unwrap());
94        buf[3] = hex_digit(u8::try_from(len & 0xF).unwrap());
95        buf[4] = b'\r';
96        buf[5] = b'\n';
97        buf[6 + len] = b'\r';
98        buf[6 + len + 1] = b'\n';
99        let bytes = &buf[..(6 + len + 2)];
100        let bytes = trim_prefix(bytes, b'0');
101        if let Err(e) = writer.write_all(bytes).await {
102            return CopyResult::WriterErr(e);
103        }
104        num_copied += len as u64;
105    }
106    if let Err(e) = writer.write_all(b"0\r\n\r\n").await {
107        return CopyResult::WriterErr(e);
108    }
109    num_copied += 3;
110    CopyResult::Ok(num_copied)
111}
112
/// Convert a byte slice into a string.
/// Includes printable ASCII characters as-is.
/// Converts non-printable or non-ASCII characters to strings like "\n" and "\x19".
///
/// Uses
/// [`core::ascii::escape_default`](https://doc.rust-lang.org/core/ascii/fn.escape_default.html)
/// internally to escape each byte.
///
/// This function is useful for printing byte slices to logs and comparing byte slices in tests.
///
/// Example test:
/// ```
/// use fixed_buffer::escape_ascii;
/// assert_eq!("abc", escape_ascii(b"abc"));
/// assert_eq!("abc\\n", escape_ascii(b"abc\n"));
/// assert_eq!(
///     "Euro sign: \\xe2\\x82\\xac",
///     escape_ascii("Euro sign: \u{20AC}".as_bytes())
/// );
/// assert_eq!("\\x01\\x02\\x03", escape_ascii(&[1, 2, 3]));
/// ```
// NOTE(review): the doctest above imports `fixed_buffer::escape_ascii`, so it exercises that
// crate's function rather than this one; confirm the two stay in sync.
#[must_use]
pub fn escape_ascii(input: &[u8]) -> String {
    // Every input byte yields at least one output character.
    let mut result = String::with_capacity(input.len());
    for byte in input {
        // `escape_default` only yields ASCII bytes, and `char::from` maps each ASCII byte to the
        // identical `char`, so no UTF-8 validation (and no panic path) is needed.
        result.extend(core::ascii::escape_default(*byte).map(char::from));
    }
    result
}
145
146#[must_use]
147#[allow(clippy::missing_panics_doc)]
148pub fn escape_and_elide(input: &[u8], max_len: usize) -> String {
149    if input.len() > max_len {
150        escape_ascii(&input[..max_len]) + "..."
151    } else {
152        escape_ascii(input)
153    }
154}
155
/// Returns the index of the first occurrence of `needle` inside `haystack`, or `None`.
///
/// An empty `needle` matches at index 0, even when `haystack` is empty.
pub fn find_slice<T: std::cmp::PartialEq>(needle: &[T], haystack: &[T]) -> Option<usize> {
    // `windows(0)` would panic, so the empty needle is answered directly.
    if needle.is_empty() {
        return Some(0);
    }
    // A needle longer than the haystack yields no windows, hence `None`.
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
166
167/// Wraps an `AsyncWrite` and records the number of bytes successfully written to it.
168pub struct AsyncWriteCounter<W>(W, u64);
169impl<W: AsyncWrite + Unpin> AsyncWriteCounter<W> {
170    pub fn new(writer: W) -> Self {
171        Self(writer, 0)
172    }
173
174    pub fn num_bytes_written(&self) -> u64 {
175        self.1
176    }
177}
178impl<W: AsyncWrite + Unpin> AsyncWrite for AsyncWriteCounter<W> {
179    fn poll_write(
180        mut self: Pin<&mut Self>,
181        cx: &mut Context<'_>,
182        buf: &[u8],
183    ) -> Poll<Result<usize, std::io::Error>> {
184        match Pin::new(&mut self.0).poll_write(cx, buf) {
185            Poll::Ready(Ok(n)) => {
186                self.1 += n as u64;
187                Poll::Ready(Ok(n))
188            }
189            other => other,
190        }
191    }
192
193    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
194        Pin::new(&mut self.0).poll_flush(cx)
195    }
196
197    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
198        Pin::new(&mut self.0).poll_close(cx)
199    }
200}