witchcraft_server/
body.rs

1// Copyright 2022 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::server::RawBody;
15use bytes::{Buf, Bytes, BytesMut};
16use conjure_error::{Error, ErrorCode, ErrorType};
17use conjure_object::Uuid;
18use futures_channel::mpsc;
19use futures_sink::Sink;
20use futures_util::{future, ready, SinkExt, Stream};
21use http::HeaderMap;
22use http_body::{Body, Frame};
23use pin_project::pin_project;
24use serde::ser::SerializeStruct;
25use serde::{Serialize, Serializer};
26use std::marker::PhantomPinned;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29use std::{io, mem};
30use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
31
32/// A streaming request body.
33#[pin_project]
34pub struct RequestBody {
35    #[pin]
36    inner: RawBody,
37    cur: Bytes,
38    trailers: Option<HeaderMap>,
39    #[pin]
40    _p: PhantomPinned,
41}
42
43impl RequestBody {
44    pub(crate) fn new(inner: RawBody) -> Self {
45        RequestBody {
46            inner,
47            cur: Bytes::new(),
48            trailers: None,
49            _p: PhantomPinned,
50        }
51    }
52    /// Returns the request's trailers, if any are present.
53    ///
54    /// The body must have been completely read before this is called.
55    pub fn trailers(self: Pin<&mut Self>) -> Option<HeaderMap> {
56        self.project().trailers.take()
57    }
58
59    fn poll_next_raw(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62    ) -> Poll<Option<Result<Bytes, hyper::Error>>> {
63        let mut this = self.project();
64
65        loop {
66            let next = ready!(this.inner.as_mut().poll_frame(cx)).transpose()?;
67
68            let Some(next) = next else {
69                return Poll::Ready(None);
70            };
71
72            let next = match next.into_data() {
73                Ok(data) => return Poll::Ready(Some(Ok(data))),
74                Err(next) => next,
75            };
76
77            if let Ok(trailers) = next.into_trailers() {
78                match this.trailers {
79                    Some(base) => base.extend(trailers),
80                    None => *this.trailers = Some(trailers),
81                }
82            }
83        }
84    }
85}
86
87impl Stream for RequestBody {
88    type Item = Result<Bytes, Error>;
89
90    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91        let this = self.as_mut().project();
92
93        if this.cur.has_remaining() {
94            return Poll::Ready(Some(Ok(mem::take(this.cur))));
95        }
96
97        self.poll_next_raw(cx)
98            .map_err(|e| Error::service_safe(e, ClientIo))
99    }
100}
101
102impl AsyncRead for RequestBody {
103    fn poll_read(
104        mut self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        buf: &mut ReadBuf<'_>,
107    ) -> Poll<io::Result<()>> {
108        let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
109        let len = usize::min(in_buf.len(), buf.remaining());
110        buf.put_slice(&in_buf[..len]);
111        self.consume(len);
112
113        Poll::Ready(Ok(()))
114    }
115}
116
117impl AsyncBufRead for RequestBody {
118    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
119        while self.cur.is_empty() {
120            match ready!(self.as_mut().poll_next_raw(cx))
121                .transpose()
122                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
123            {
124                Some(bytes) => *self.as_mut().project().cur = bytes,
125                None => break,
126            }
127        }
128
129        Poll::Ready(Ok(self.project().cur))
130    }
131
132    fn consume(self: Pin<&mut Self>, amt: usize) {
133        self.project().cur.advance(amt)
134    }
135}
136
137/// The writer used for streaming response bodies.
138#[pin_project]
139pub struct ResponseWriter {
140    #[pin]
141    sender: mpsc::Sender<Frame<Bytes>>,
142    buf: BytesMut,
143    #[pin]
144    _p: PhantomPinned,
145}
146
147impl ResponseWriter {
148    pub(crate) fn new(sender: mpsc::Sender<Frame<Bytes>>) -> Self {
149        ResponseWriter {
150            sender,
151            buf: BytesMut::new(),
152            _p: PhantomPinned,
153        }
154    }
155
156    /// Like [`Sink::start_send`] except that it sends the response's trailers.
157    ///
158    /// The body must be fully written before calling this method.
159    pub fn start_send_trailers(self: Pin<&mut Self>, trailers: HeaderMap) -> Result<(), Error> {
160        self.start_send_inner(Frame::trailers(trailers))
161    }
162
163    /// Like [`SinkExt::send`] except that it sends the response's trailers.
164    ///
165    /// The body must be fully written before calling this method.
166    pub async fn send_trailers(mut self: Pin<&mut Self>, trailers: HeaderMap) -> Result<(), Error> {
167        future::poll_fn(|cx| self.as_mut().poll_flush_shallow(cx))
168            .await
169            .map_err(|e| Error::service_safe(e, ClientIo))?;
170
171        self.project()
172            .sender
173            .send(Frame::trailers(trailers))
174            .await
175            .map_err(|e| Error::service_safe(e, ClientIo))
176    }
177
178    pub(crate) async fn finish(mut self: Pin<&mut Self>) -> Result<(), Error> {
179        self.flush().await
180    }
181
182    fn start_send_inner(self: Pin<&mut Self>, item: Frame<Bytes>) -> Result<(), Error> {
183        let this = self.project();
184
185        assert!(this.buf.is_empty());
186        this.sender
187            .start_send(item)
188            .map_err(|e| Error::service_safe(e, ClientIo))
189    }
190
191    fn poll_flush_shallow(
192        self: Pin<&mut Self>,
193        cx: &mut Context<'_>,
194    ) -> Poll<Result<(), mpsc::SendError>> {
195        let mut this = self.project();
196
197        if this.buf.is_empty() {
198            return Poll::Ready(Ok(()));
199        }
200
201        ready!(this.sender.as_mut().poll_ready(cx))?;
202        this.sender
203            .start_send(Frame::data(this.buf.split().freeze()))?;
204
205        Poll::Ready(Ok(()))
206    }
207}
208
209impl Sink<Bytes> for ResponseWriter {
210    type Error = Error;
211
212    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213        ready!(self.as_mut().poll_flush_shallow(cx))
214            .map_err(|e| Error::service_safe(e, ClientIo))?;
215
216        self.project()
217            .sender
218            .poll_ready(cx)
219            .map_err(|e| Error::service_safe(e, ClientIo))
220    }
221
222    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
223        self.start_send_inner(Frame::data(item))
224    }
225
226    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227        ready!(self.as_mut().poll_flush_shallow(cx))
228            .map_err(|e| Error::service_safe(e, ClientIo))?;
229
230        self.project()
231            .sender
232            .poll_flush(cx)
233            .map_err(|e| Error::service_safe(e, ClientIo))
234    }
235
236    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
237        ready!(self.as_mut().poll_flush_shallow(cx))
238            .map_err(|e| Error::service_safe(e, ClientIo))?;
239
240        self.project()
241            .sender
242            .poll_close(cx)
243            .map_err(|e| Error::service_safe(e, ClientIo))
244    }
245}
246
247impl AsyncWrite for ResponseWriter {
248    fn poll_write(
249        mut self: Pin<&mut Self>,
250        cx: &mut Context<'_>,
251        buf: &[u8],
252    ) -> Poll<io::Result<usize>> {
253        if self.buf.len() > 4096 {
254            ready!(self.as_mut().poll_flush_shallow(cx))
255                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
256        }
257
258        self.project().buf.extend_from_slice(buf);
259        Poll::Ready(Ok(buf.len()))
260    }
261
262    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
263        ready!(self.as_mut().poll_flush_shallow(cx))
264            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
265
266        self.project()
267            .sender
268            .poll_flush(cx)
269            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
270    }
271
272    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
273        ready!(self.as_mut().poll_flush_shallow(cx))
274            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
275
276        self.project()
277            .sender
278            .poll_close(cx)
279            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
280    }
281}
282
/// Error type attached to failures that occur while reading a request body or sending
/// response frames.
///
/// It is a data-free marker: it serializes as an empty struct and exposes no arguments.
/// `Debug` is derived so values can be formatted in logs and assertions.
#[derive(Debug)]
pub(crate) struct ClientIo;
284
285impl Serialize for ClientIo {
286    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
287    where
288        S: Serializer,
289    {
290        serializer.serialize_struct("ClientIo", 0)?.end()
291    }
292}
293
294impl ErrorType for ClientIo {
295    fn code(&self) -> ErrorCode {
296        ErrorCode::CustomClient
297    }
298
299    fn name(&self) -> &str {
300        "Witchcraft:ClientIo"
301    }
302
303    fn instance_id(&self) -> Option<Uuid> {
304        None
305    }
306
307    fn safe_args(&self) -> &'static [&'static str] {
308        &[]
309    }
310}
311
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: building a service error from `ClientIo` must not panic.
    #[test]
    fn conjure_error_from_client_io() {
        let _error = Error::service_safe("", ClientIo);
    }
}