witchcraft_server/blocking/
body.rs

1// Copyright 2022 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::body::ClientIo;
15use crate::server::RawBody;
16use bytes::{Buf, Bytes, BytesMut};
17use conjure_error::Error;
18use futures_channel::mpsc;
19use futures_util::{Future, SinkExt};
20use http::HeaderMap;
21use http_body::Frame;
22use http_body_util::BodyExt;
23use std::io::{BufRead, Read, Write};
24use std::time::Duration;
25use std::{error, io, mem};
26use tokio::runtime::Handle;
27use tokio::time;
28
/// Upper bound on how long a single blocking body operation (reading one frame, or handing one part to the response
/// channel) may take before it fails with a timeout error.
const IO_TIMEOUT: Duration = Duration::from_secs(60);
30
31/// A streaming request body for blocking requests.
32pub struct RequestBody {
33    inner: RawBody,
34    handle: Handle,
35    cur: Bytes,
36    trailers: Option<HeaderMap>,
37}
38
39impl RequestBody {
40    pub(crate) fn new(inner: RawBody, handle: Handle) -> Self {
41        RequestBody {
42            inner,
43            handle,
44            cur: Bytes::new(),
45            trailers: None,
46        }
47    }
48
49    /// Returns the request's trailers, if any are present.
50    ///
51    /// The body must have been completely read before this is called.
52    pub fn trailers(&mut self) -> Option<HeaderMap> {
53        self.trailers.take()
54    }
55
56    fn next_raw(&mut self) -> Result<Option<Bytes>, Box<dyn error::Error + Sync + Send>> {
57        loop {
58            let next = self
59                .handle
60                .block_on(async { time::timeout(IO_TIMEOUT, self.inner.frame()).await })?
61                .transpose()?;
62
63            let Some(next) = next else {
64                return Ok(None);
65            };
66
67            let next = match next.into_data() {
68                Ok(data) => return Ok(Some(data)),
69                Err(next) => next,
70            };
71
72            if let Ok(trailers) = next.into_trailers() {
73                match &mut self.trailers {
74                    Some(base) => base.extend(trailers),
75                    None => self.trailers = Some(trailers),
76                }
77            }
78        }
79    }
80}
81
82impl Iterator for RequestBody {
83    type Item = Result<Bytes, Error>;
84
85    fn next(&mut self) -> Option<Self::Item> {
86        if self.cur.has_remaining() {
87            return Some(Ok(mem::take(&mut self.cur)));
88        }
89
90        self.next_raw()
91            .map_err(|e| Error::service_safe(e, ClientIo))
92            .transpose()
93    }
94}
95
96impl Read for RequestBody {
97    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
98        let in_buf = self.fill_buf()?;
99        let len = usize::min(in_buf.len(), buf.len());
100        buf[..len].copy_from_slice(&in_buf[..len]);
101        self.consume(len);
102        Ok(len)
103    }
104}
105
106impl BufRead for RequestBody {
107    fn fill_buf(&mut self) -> io::Result<&[u8]> {
108        while self.cur.is_empty() {
109            match self.next_raw().map_err(io::Error::other)? {
110                Some(bytes) => self.cur = bytes,
111                None => break,
112            }
113        }
114
115        Ok(&self.cur)
116    }
117
118    fn consume(&mut self, amt: usize) {
119        self.cur.advance(amt)
120    }
121}
122
123pub enum BodyPart {
124    Frame(Frame<Bytes>),
125    Done,
126}
127
128/// The writer used for streaming response bodies of blocking endpoints.
129pub struct ResponseWriter {
130    sender: mpsc::Sender<BodyPart>,
131    handle: Handle,
132    buf: BytesMut,
133}
134
135impl ResponseWriter {
136    pub(crate) fn new(sender: mpsc::Sender<BodyPart>, handle: Handle) -> Self {
137        Self {
138            sender,
139            handle,
140            buf: BytesMut::new(),
141        }
142    }
143
144    /// Writes a block of [`Bytes`] to the response body.
145    ///
146    /// Compared to `ResponseWriter`'s [`Write`] implementation, this method can avoid some copies if the data is
147    /// already represented as a [`Bytes`] value.
148    pub fn send(&mut self, bytes: Bytes) -> Result<(), Error> {
149        self.send_inner(BodyPart::Frame(Frame::data(bytes)))
150    }
151
152    /// Writes the response's trailers.
153    ///
154    /// The body must be fully written before calling this method.
155    pub fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
156        self.send_inner(BodyPart::Frame(Frame::trailers(trailers)))
157    }
158
159    fn send_inner(&mut self, part: BodyPart) -> Result<(), Error> {
160        self.flush_shallow()
161            .map_err(|e| Error::service_safe(e, ClientIo))?;
162
163        Self::with_timeout(&self.handle, self.sender.feed(part))
164            .map_err(|e| Error::service_safe(e, ClientIo))?;
165
166        Ok(())
167    }
168
169    fn with_timeout<F, R, E>(
170        handle: &Handle,
171        future: F,
172    ) -> Result<R, Box<dyn error::Error + Sync + Send>>
173    where
174        F: Future<Output = Result<R, E>>,
175        E: Into<Box<dyn error::Error + Sync + Send>>,
176    {
177        handle
178            .block_on(async { time::timeout(IO_TIMEOUT, future).await })?
179            .map_err(Into::into)
180    }
181
182    fn flush_shallow(&mut self) -> Result<(), Box<dyn error::Error + Sync + Send>> {
183        if self.buf.is_empty() {
184            return Ok(());
185        }
186
187        Self::with_timeout(
188            &self.handle,
189            self.sender
190                .feed(BodyPart::Frame(Frame::data(self.buf.split().freeze()))),
191        )
192    }
193
194    pub(crate) fn finish(mut self) -> Result<(), Error> {
195        self.flush_shallow()
196            .map_err(|e| Error::service_safe(e, ClientIo))?;
197
198        Self::with_timeout(&self.handle, self.sender.send(BodyPart::Done))
199            .map_err(|e| Error::service_safe(e, ClientIo))?;
200
201        Ok(())
202    }
203}
204
205impl Write for ResponseWriter {
206    #[inline]
207    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
208        if self.buf.len() > 4096 {
209            self.flush_shallow().map_err(io::Error::other)?;
210        }
211
212        self.buf.extend_from_slice(buf);
213        Ok(buf.len())
214    }
215
216    fn flush(&mut self) -> io::Result<()> {
217        self.flush_shallow().map_err(io::Error::other)?;
218
219        Self::with_timeout(&self.handle, self.sender.flush()).map_err(io::Error::other)?;
220
221        Ok(())
222    }
223}