1use crate::util::{unwrap_poison, unwrap_some};
4use crate::{error_log, TiiResult};
5use libflate::gzip;
6use std::fmt::{Debug, Formatter};
7use std::io;
8use std::io::{Cursor, Error, ErrorKind, Read, Take};
9use std::ops::DerefMut;
10use std::sync::{Arc, Mutex};
11
12#[derive(Debug, Clone)]
19#[repr(transparent)]
20pub struct RequestBody(Arc<Mutex<RequestBodyInner>>);
21
22impl Eq for RequestBody {}
23impl PartialEq for RequestBody {
24 fn eq(&self, other: &Self) -> bool {
25 Arc::ptr_eq(&self.0, &other.0)
26 }
27}
28
29impl RequestBody {
30 pub fn new_with_data_ref<T: AsRef<[u8]>>(data: T) -> RequestBody {
33 Self::new_with_data(data.as_ref().to_vec())
34 }
35
36 pub fn new_with_data(data: Vec<u8>) -> RequestBody {
38 let len = data.len() as u64;
39 let cursor = Cursor::new(data);
40 Self::new_with_content_length(Box::new(cursor) as Box<dyn Read + Send + 'static>, len)
41 }
42
43 pub fn new_with_content_length<T: Read + Send + 'static>(read: T, len: u64) -> RequestBody {
45 RequestBody(Arc::new(Mutex::new(RequestBodyInner::WithContentLength(
46 RequestBodyWithContentLength {
47 err: false,
48 data: (Box::new(read) as Box<dyn Read + Send>).take(len),
49 },
50 ))))
51 }
52
53 pub fn new_chunked<T: Read + Send + 'static>(read: T) -> RequestBody {
55 RequestBody(Arc::new(Mutex::new(RequestBodyInner::Chunked(RequestBodyChunked {
56 read: Box::new(read) as Box<dyn Read + Send>,
57 eof: false,
58 err: false,
59 remaining_chunk_length: 0,
60 }))))
61 }
62
63 pub fn new_gzip_chunked<T: Read + Send + 'static>(read: T) -> TiiResult<RequestBody> {
66 let inner = RequestBodyInner::Chunked(RequestBodyChunked {
67 read: Box::new(read) as Box<dyn Read + Send>,
68 eof: false,
69 err: false,
70 remaining_chunk_length: 0,
71 });
72
73 Ok(RequestBody(Arc::new(Mutex::new(RequestBodyInner::Gzip(GzipRequestBody::new(inner)?)))))
74 }
75
76 pub fn new_gzip_with_uncompressed_length<T: Read + Send + 'static>(
79 read: T,
80 len: u64,
81 ) -> TiiResult<RequestBody> {
82 let decoder = gzip::Decoder::new(read).inspect_err(|e| {
83 error_log!("Could not decode gzip header of request body: {}", e);
84 })?;
85
86 Ok(Self::new_with_content_length(decoder, len))
87 }
88
89 pub fn new_gzip_with_compressed_content_length<T: Read + Send + 'static>(
92 read: T,
93 len: u64,
94 ) -> TiiResult<RequestBody> {
95 let inner = RequestBodyInner::WithContentLength(RequestBodyWithContentLength {
96 err: false,
97 data: (Box::new(read) as Box<dyn Read + Send>).take(len),
98 });
99
100 Ok(RequestBody(Arc::new(Mutex::new(RequestBodyInner::Gzip(GzipRequestBody::new(inner)?)))))
101 }
102}
103
104impl RequestBody {
105 pub fn as_read(&self) -> impl Read + '_ {
108 Box::new(self)
109 }
110
111 pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
113 unwrap_poison(self.0.lock())?.deref_mut().read(buf)
114 }
115
116 pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
118 unwrap_poison(self.0.lock())?.deref_mut().read_to_end(buf)
119 }
120
121 pub fn read_to_vec(&self) -> io::Result<Vec<u8>> {
123 let mut buffer = Vec::new();
124 self.read_to_end(&mut buffer)?;
125 Ok(buffer)
126 }
127
128 pub fn read_exact(&self, buf: &mut [u8]) -> io::Result<()> {
130 unwrap_poison(self.0.lock())?.deref_mut().read_exact(buf)
131 }
132
133 pub fn remaining(&self) -> io::Result<Option<u64>> {
138 Ok(match unwrap_poison(self.0.lock())?.deref_mut() {
139 RequestBodyInner::WithContentLength(wc) => Some(wc.data.limit()),
140 _ => None,
141 })
142 }
143}
144
145impl Read for &RequestBody {
146 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
147 RequestBody::read(self, buf) }
149}
150
151#[derive(Debug)]
152enum RequestBodyInner {
153 WithContentLength(RequestBodyWithContentLength),
154 Chunked(RequestBodyChunked),
155 Gzip(GzipRequestBody),
156 }
158
159impl Read for RequestBodyInner {
160 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
161 match self {
162 RequestBodyInner::WithContentLength(body) => body.read(buf),
163 RequestBodyInner::Chunked(body) => body.read(buf),
164 RequestBodyInner::Gzip(body) => body.read(buf),
165 }
166 }
167}
168
169#[derive(Debug)]
170struct GzipRequestBody {
171 err: bool,
172 decoder: Option<gzip::Decoder<Box<RequestBodyInner>>>,
173}
174
175impl GzipRequestBody {
176 fn new(inner: RequestBodyInner) -> TiiResult<Self> {
177 let decoder = gzip::Decoder::new(Box::new(inner)).inspect_err(|e| {
178 error_log!("Could not decode gzip header of request body: {}", e);
179 })?;
180 Ok(Self { err: false, decoder: Some(decoder) })
181 }
182}
183
184impl Read for GzipRequestBody {
185 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
186 if self.err {
187 return Err(Error::new(ErrorKind::BrokenPipe, "Previous IO Error reading body"));
188 }
189 let Some(dec) = self.decoder.as_mut() else {
190 return Ok(0);
191 };
192
193 let count = dec.read(buf).inspect_err(|_| self.err = true)?;
194 if count == 0 {
195 let mut small_buf = [0u8];
197 let count = unwrap_some(self.decoder.take())
198 .into_inner()
199 .read(small_buf.as_mut_slice())
200 .inspect_err(|_| self.err = true)?;
201 if count != 0 {
202 self.err = true;
203 return Err(Error::new(ErrorKind::BrokenPipe, "Gzip decoded did not fully consume data"));
204 }
205 }
206 Ok(count)
207 }
208}
209
/// Stream limited to a fixed number of bytes.
struct RequestBodyWithContentLength {
    /// Set once a read has failed; later reads fail immediately.
    err: bool,
    /// The source stream, capped at the number of bytes still allowed to be read.
    data: Take<Box<dyn Read + Send>>,
}
214
215impl Read for RequestBodyWithContentLength {
216 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
217 if self.err {
218 return Err(Error::new(
219 ErrorKind::BrokenPipe,
220 "Transfer stream has failed due to previous error",
221 ));
222 }
223 self.data.read(buf).inspect_err(|_| self.err = true)
224 }
225}
226
227impl Debug for RequestBodyWithContentLength {
228 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229 f.write_fmt(format_args!("RequestBodyWithContentLength(remaining={})", self.data.limit()))
230 }
231}
232
/// Stream in chunked transfer encoding.
struct RequestBodyChunked {
    /// The raw (still chunk-framed) source stream.
    read: Box<dyn Read + Send>,
    /// Set once the terminating zero-length chunk has been read.
    eof: bool,
    /// Set once a read has failed; later reads fail immediately.
    err: bool,
    /// Payload bytes of the current chunk that have not been handed out yet.
    remaining_chunk_length: u64,
}
239
240impl Debug for RequestBodyChunked {
241 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242 f.write_fmt(format_args!(
243 "RequestBodyChunked(eof={} remaining_chunk_length={})",
244 self.eof, self.remaining_chunk_length
245 ))
246 }
247}
248
249impl RequestBodyChunked {
250 #[expect(clippy::indexing_slicing, reason = "we break if n >= 17")]
251 fn read_internal(&mut self, buf: &mut [u8]) -> io::Result<usize> {
252 if buf.is_empty() {
253 return Ok(0);
254 }
255
256 if self.eof {
257 return Ok(0);
258 }
259
260 if self.remaining_chunk_length > 0 {
261 let to_read = u64::min(buf.len() as u64, self.remaining_chunk_length) as usize;
262 let read = self.read.read(&mut buf[..to_read])?;
263 if read == 0 {
264 return Err(Error::new(
265 ErrorKind::UnexpectedEof,
266 "chunked transfer encoding suggest more data",
267 ));
268 }
269
270 self.remaining_chunk_length =
271 unwrap_some(self.remaining_chunk_length.checked_sub(read as u64));
272 if self.remaining_chunk_length == 0 {
273 let mut tiny_buffer = [0u8; 1];
274 self.read.read_exact(&mut tiny_buffer)?;
275 if tiny_buffer[0] != b'\r' {
276 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk trailer is malformed"));
277 }
278 self.read.read_exact(&mut tiny_buffer)?;
279 if tiny_buffer[0] != b'\n' {
280 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk trailer is malformed"));
281 }
282 }
283 return Ok(read);
284 }
285
286 let mut small_buffer = [0u8; 32];
287 let mut n = 0;
288 loop {
289 if n >= 17 {
290 return Err(Error::new(
292 io::ErrorKind::InvalidData,
293 "Chunk size is larger than 2^64 or malformed",
294 ));
295 }
296 self.read.read_exact(&mut small_buffer[n..n + 1])?;
297 if small_buffer[n] == b'\r' {
298 self.read.read_exact(&mut small_buffer[n..n + 1])?;
299 if small_buffer[n] != b'\n' {
300 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk size is malformed"));
301 }
302 break;
303 }
304
305 n += 1;
306 }
307
308 if n == 0 {
309 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk size is malformed"));
310 }
311
312 let str = std::str::from_utf8(&small_buffer[0..n])
313 .map_err(|_| Error::new(io::ErrorKind::InvalidData, "Chunk size is malformed"))?;
314 let chunk_len = u64::from_str_radix(str, 16)
315 .map_err(|_| Error::new(io::ErrorKind::InvalidData, "Chunk size is malformed"))?;
316 if chunk_len == 0 {
317 self.read.read_exact(&mut small_buffer[n..n + 1])?;
318 if small_buffer[n] != b'\r' {
319 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk trailer is malformed"));
320 }
321 self.read.read_exact(&mut small_buffer[n..n + 1])?;
322 if small_buffer[n] != b'\n' {
323 return Err(Error::new(io::ErrorKind::InvalidData, "Chunk trailer is malformed"));
324 }
325
326 self.eof = true;
327 return Ok(0);
328 }
329
330 self.remaining_chunk_length = chunk_len;
331 self.read(buf)
332 }
333}
334
335impl Read for RequestBodyChunked {
336 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
337 if self.err {
338 return Err(Error::new(
339 ErrorKind::BrokenPipe,
340 "Chunked transfer stream has failed due to previous error",
341 ));
342 }
343 self.read_internal(buf).inspect_err(|_| self.err = true)
344 }
345}