1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5 fmt::{self, Debug, Formatter},
6 io, mem,
7 pin::Pin,
8 task::{Context, Poll, ready},
9};
10use trillium_http::{
11 Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
15pub struct ResponseBody<'a> {
46 inner: ResponseBodyInner<'a>,
47 cleanup: Option<CleanupContext>,
54 trailers: Option<Headers>,
59}
60
61#[allow(clippy::large_enum_variant)]
62enum ResponseBodyInner<'a> {
63 Received(ReceivedBody<'a, Box<dyn Transport>>),
64 Override(OverrideBody<'a>),
65 Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
66 Closed,
67}
68
69type H1Pool = Pool<Origin, Box<dyn Transport>>;
70
71#[derive(Clone)]
81pub(crate) struct CleanupContext {
82 pub(crate) runtime: Runtime,
83 pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
84}
85
86impl CleanupContext {
87 pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
90 match &self.h1_pool_origin {
91 Some((pool, origin)) => {
92 log::trace!("body transferred, returning to pool");
93 pool.insert(origin.clone(), PoolEntry::new(transport, None));
94 }
95 None => {
96 self.runtime.clone().spawn(async move {
97 let _ = transport.close().await;
98 });
99 }
100 }
101 }
102}
103
104pub(crate) struct OverrideBody<'a> {
105 body: MutCow<'a, Body>,
106 encoding: &'static Encoding,
107 max_len: u64,
108 initial_len: usize,
109 max_preallocate: usize,
110}
111
112impl AsyncRead for OverrideBody<'_> {
113 fn poll_read(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &mut [u8],
117 ) -> Poll<io::Result<usize>> {
118 let remaining = self.max_len.saturating_sub(self.body.bytes_read());
119 if remaining == 0 && !buf.is_empty() {
120 return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
121 self.max_len,
122 ))));
123 }
124 let cap = remaining.min(buf.len() as u64) as usize;
125 Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
126 }
127}
128
129impl<'a> OverrideBody<'a> {
130 pub(crate) fn new(
131 body: impl Into<MutCow<'a, Body>>,
132 encoding: &'static Encoding,
133 http_config: &HttpConfig,
134 ) -> Self {
135 Self {
136 body: body.into(),
137 encoding,
138 max_len: http_config.received_body_max_len(),
139 max_preallocate: http_config.received_body_max_preallocate(),
140 initial_len: http_config.received_body_initial_len(),
141 }
142 }
143}
144
145impl Debug for ResponseBody<'_> {
146 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147 match &self.inner {
148 ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
149 ResponseBodyInner::Override(o) => f
150 .debug_struct("ResponseBody (override)")
151 .field("body", &*o.body)
152 .field("encoding", &o.encoding.name())
153 .field("max_len", &o.max_len)
154 .finish(),
155 ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
156 ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
157 }
158 }
159}
160
161impl AsyncRead for ResponseBody<'_> {
162 fn poll_read(
163 mut self: Pin<&mut Self>,
164 cx: &mut Context<'_>,
165 buf: &mut [u8],
166 ) -> Poll<io::Result<usize>> {
167 let mut bytes = 0;
168 loop {
169 match &mut self.inner {
170 ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
171 ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
172 ResponseBodyInner::Closing(fut) => {
173 ready!(fut.as_mut().poll(cx));
174 self.inner = ResponseBodyInner::Closed;
175 break;
176 }
177
178 ResponseBodyInner::Closed => break,
179 };
180
181 if bytes == 0
184 && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
185 && rb.state() == ReceivedBodyState::End
186 && let Some(mut transport) = rb.take_transport()
187 {
188 self.trailers = Pin::new(&mut rb).trailers();
189 if let Some((pool, origin)) = cleanup.h1_pool_origin {
190 pool.insert(origin, PoolEntry::new(transport, None));
191 } else {
192 self.inner = ResponseBodyInner::Closing(Box::pin(async move {
193 if let Err(e) = transport.close().await {
194 log::warn!("transport close failed during ResponseBody EOF: {e}");
195 }
196 }));
197 }
198 } else {
199 break;
200 }
201 }
202
203 Poll::Ready(Ok(bytes))
204 }
205}
206
207impl ResponseBody<'_> {
208 fn take_inner(&mut self) -> ResponseBodyInner<'_> {
209 mem::replace(&mut self.inner, ResponseBodyInner::Closed)
210 }
211
212 fn max_preallocate(&self) -> usize {
213 match &self.inner {
214 ResponseBodyInner::Received(rb) => rb.max_preallocate(),
215 ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
216 _ => 0,
217 }
218 }
219
220 fn max_len(&self) -> u64 {
221 match &self.inner {
222 ResponseBodyInner::Received(rb) => rb.max_len(),
223 ResponseBodyInner::Override(override_body) => override_body.max_len,
224 _ => 0,
225 }
226 }
227
228 fn initial_len(&self) -> usize {
229 match &self.inner {
230 ResponseBodyInner::Received(rb) => rb.initial_len(),
231 ResponseBodyInner::Override(override_body) => override_body.initial_len,
232 _ => 0,
233 }
234 }
235
236 fn encoding(&self) -> &'static Encoding {
237 match &self.inner {
238 ResponseBodyInner::Received(rb) => rb.encoding(),
239 ResponseBodyInner::Override(override_body) => override_body.encoding,
240 _ => encoding_rs::WINDOWS_1252,
241 }
242 }
243
244 pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
261 let mut vec = if let Some(len) = self.content_length() {
262 if len > self.max_len() {
263 return Err(Error::ReceivedBodyTooLong(self.max_len()));
264 }
265
266 let len =
267 usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
268
269 Vec::with_capacity(len.min(self.max_preallocate()))
270 } else {
271 Vec::with_capacity(self.initial_len())
272 };
273
274 self.read_to_end(&mut vec).await?;
275
276 Ok(vec)
277 }
278
279 pub async fn read_string(self) -> Result<String, Error> {
299 let encoding = self.encoding();
300 let bytes = self.read_bytes().await?;
301 let (s, _, _) = encoding.decode(&bytes);
302 Ok(s.to_string())
303 }
304
305 #[must_use]
315 pub fn with_max_len(mut self, max_len: u64) -> Self {
316 self.set_max_len(max_len);
317 self
318 }
319
320 pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
330 match &mut self.inner {
331 ResponseBodyInner::Received(rb) => {
332 rb.set_max_len(max_len);
333 }
334 ResponseBodyInner::Override(o) => {
335 o.max_len = max_len;
336 }
337 _ => {}
338 }
339 self
340 }
341
342 pub fn trailers(&self) -> Option<&Headers> {
350 match &self.inner {
351 ResponseBodyInner::Received(rb) => rb.trailers_ref(),
352 _ => self.trailers.as_ref(),
354 }
355 }
356
357 pub fn content_length(&self) -> Option<u64> {
362 match &self.inner {
363 ResponseBodyInner::Received(rb) => rb.content_length(),
364 ResponseBodyInner::Override(o) => o.body.len(),
365 _ => None,
366 }
367 }
368
369 fn prepare_for_recycle(
370 &mut self,
371 ) -> Option<(
372 ReceivedBody<'static, Box<dyn Transport + 'static>>,
373 CleanupContext,
374 )> {
375 let cleanup = self.cleanup.take()?;
376
377 let ResponseBodyInner::Received(rb) = self.take_inner() else {
378 return None;
379 };
380
381 let rb = rb.try_into_owned()?;
382
383 Some((rb, cleanup))
384 }
385}
386
387async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
388 let copy_loops_per_yield = rb.copy_loops_per_yield();
389 trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
390}
391
392async fn recycle(
393 mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
394 h1_pool_origin: Option<(H1Pool, Origin)>,
395) {
396 if let Some((pool, origin)) = h1_pool_origin {
397 match drain(&mut rb).await {
398 Ok(drained) => {
399 if rb.state() == ReceivedBodyState::End
400 && let Some(transport) = rb.take_transport()
401 {
402 log::trace!(
403 "drained {drained} bytes, returning transport to pool for {origin:?}"
404 );
405 pool.insert(origin, PoolEntry::new(transport, None));
406 return;
407 }
408 }
409 Err(e) => log::warn!("drain failed during recycle: {e}"),
410 }
411 }
412
413 if let Some(mut transport) = rb.take_transport()
414 && let Err(e) = transport.close().await
415 {
416 log::warn!("transport close failed during recycle: {e}");
417 }
418}
419
420impl Drop for ResponseBody<'_> {
421 fn drop(&mut self) {
422 let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
423 return;
424 };
425
426 if rb.state() == ReceivedBodyState::End
428 && cleanup.h1_pool_origin.is_some()
429 && let Some(transport) = rb.take_transport()
430 && let Some((pool, origin)) = cleanup.h1_pool_origin
431 {
432 pool.insert(origin, PoolEntry::new(transport, None));
433 } else {
434 cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
435 }
436 }
437}
438
439impl BodySource for ResponseBody<'static> {
440 fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
441 let this = self.get_mut();
442 match &mut this.inner {
443 ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
444 ResponseBodyInner::Override(o) => o.body.trailers(),
445 _ => this.trailers.take(),
448 }
449 }
450}
451
452impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
453 fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
454 Self {
455 inner: ResponseBodyInner::Received(received_body),
456 cleanup: None,
457 trailers: None,
458 }
459 }
460}
461
462impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
463 fn from(o: OverrideBody<'a>) -> Self {
464 Self {
465 inner: ResponseBodyInner::Override(o),
466 cleanup: None,
467 trailers: None,
468 }
469 }
470}
471
472impl ResponseBody<'static> {
473 pub(crate) fn received_owned(
474 body: ReceivedBody<'static, Box<dyn Transport>>,
475 cleanup: CleanupContext,
476 ) -> Self {
477 Self {
478 inner: ResponseBodyInner::Received(body),
479 cleanup: Some(cleanup),
480 trailers: None,
481 }
482 }
483
484 pub async fn recycle(mut self) {
495 let Some((rb, cleanup)) = self.prepare_for_recycle() else {
496 return;
497 };
498
499 recycle(rb, cleanup.h1_pool_origin).await;
500 }
501}
502
503impl<'a> IntoFuture for ResponseBody<'a> {
504 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
505 type Output = trillium_http::Result<String>;
506
507 fn into_future(self) -> Self::IntoFuture {
508 Box::pin(async move { self.read_string().await })
509 }
510}
511
512const _: fn() = || {
513 fn assert_send_sync<T: Send + Sync + ?Sized>() {}
514 assert_send_sync::<ResponseBody<'static>>();
515 assert_send_sync::<ResponseBody<'_>>();
516};