1use crate::server::RawBody;
15use bytes::{Buf, Bytes, BytesMut};
16use conjure_error::{Error, ErrorCode, ErrorType};
17use conjure_object::Uuid;
18use futures_channel::mpsc;
19use futures_sink::Sink;
20use futures_util::{future, ready, SinkExt, Stream};
21use http::HeaderMap;
22use http_body::{Body, Frame};
23use pin_project::pin_project;
24use serde::ser::SerializeStruct;
25use serde::{Serialize, Serializer};
26use std::marker::PhantomPinned;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29use std::{io, mem};
30use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
31
#[pin_project]
/// A streaming request body.
///
/// Wraps the server's raw body and exposes it through the `Stream`, `AsyncRead`, and
/// `AsyncBufRead` implementations on this type. Trailer frames seen while reading are
/// collected and can be retrieved with [`RequestBody::trailers`].
pub struct RequestBody {
    // The underlying frame source; structurally pinned because it is polled in place.
    #[pin]
    inner: RawBody,
    // The chunk currently being handed out by the buffered-read implementation. Empty when
    // the next read must pull a new frame from `inner`.
    cur: Bytes,
    // Trailers accumulated from trailer frames; `None` until one is seen or after they have
    // been taken.
    trailers: Option<HeaderMap>,
    // Opts the type out of `Unpin`.
    #[pin]
    _p: PhantomPinned,
}
42
43impl RequestBody {
44 pub(crate) fn new(inner: RawBody) -> Self {
45 RequestBody {
46 inner,
47 cur: Bytes::new(),
48 trailers: None,
49 _p: PhantomPinned,
50 }
51 }
52 pub fn trailers(self: Pin<&mut Self>) -> Option<HeaderMap> {
56 self.project().trailers.take()
57 }
58
59 fn poll_next_raw(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 ) -> Poll<Option<Result<Bytes, hyper::Error>>> {
63 let mut this = self.project();
64
65 loop {
66 let next = ready!(this.inner.as_mut().poll_frame(cx)).transpose()?;
67
68 let Some(next) = next else {
69 return Poll::Ready(None);
70 };
71
72 let next = match next.into_data() {
73 Ok(data) => return Poll::Ready(Some(Ok(data))),
74 Err(next) => next,
75 };
76
77 if let Ok(trailers) = next.into_trailers() {
78 match this.trailers {
79 Some(base) => base.extend(trailers),
80 None => *this.trailers = Some(trailers),
81 }
82 }
83 }
84 }
85}
86
87impl Stream for RequestBody {
88 type Item = Result<Bytes, Error>;
89
90 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 let this = self.as_mut().project();
92
93 if this.cur.has_remaining() {
94 return Poll::Ready(Some(Ok(mem::take(this.cur))));
95 }
96
97 self.poll_next_raw(cx)
98 .map_err(|e| Error::service_safe(e, ClientIo))
99 }
100}
101
102impl AsyncRead for RequestBody {
103 fn poll_read(
104 mut self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 buf: &mut ReadBuf<'_>,
107 ) -> Poll<io::Result<()>> {
108 let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
109 let len = usize::min(in_buf.len(), buf.remaining());
110 buf.put_slice(&in_buf[..len]);
111 self.consume(len);
112
113 Poll::Ready(Ok(()))
114 }
115}
116
117impl AsyncBufRead for RequestBody {
118 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
119 while self.cur.is_empty() {
120 match ready!(self.as_mut().poll_next_raw(cx))
121 .transpose()
122 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
123 {
124 Some(bytes) => *self.as_mut().project().cur = bytes,
125 None => break,
126 }
127 }
128
129 Poll::Ready(Ok(self.project().cur))
130 }
131
132 fn consume(self: Pin<&mut Self>, amt: usize) {
133 self.project().cur.advance(amt)
134 }
135}
136
#[pin_project]
/// A streaming response body writer.
///
/// Frames are pushed into an `mpsc` channel. Bytes can be supplied either as whole chunks
/// through the `Sink<Bytes>` implementation or incrementally through the `AsyncWrite`
/// implementation, which accumulates them in an internal buffer before forwarding.
pub struct ResponseWriter {
    // Channel that carries response frames out of the writer.
    #[pin]
    sender: mpsc::Sender<Frame<Bytes>>,
    // Bytes written through `AsyncWrite` that have not yet been sent as a data frame.
    buf: BytesMut,
    // Opts the type out of `Unpin`.
    #[pin]
    _p: PhantomPinned,
}
146
147impl ResponseWriter {
148 pub(crate) fn new(sender: mpsc::Sender<Frame<Bytes>>) -> Self {
149 ResponseWriter {
150 sender,
151 buf: BytesMut::new(),
152 _p: PhantomPinned,
153 }
154 }
155
156 pub fn start_send_trailers(self: Pin<&mut Self>, trailers: HeaderMap) -> Result<(), Error> {
160 self.start_send_inner(Frame::trailers(trailers))
161 }
162
163 pub async fn send_trailers(mut self: Pin<&mut Self>, trailers: HeaderMap) -> Result<(), Error> {
167 future::poll_fn(|cx| self.as_mut().poll_flush_shallow(cx))
168 .await
169 .map_err(|e| Error::service_safe(e, ClientIo))?;
170
171 self.project()
172 .sender
173 .send(Frame::trailers(trailers))
174 .await
175 .map_err(|e| Error::service_safe(e, ClientIo))
176 }
177
178 pub(crate) async fn finish(mut self: Pin<&mut Self>) -> Result<(), Error> {
179 self.flush().await
180 }
181
182 fn start_send_inner(self: Pin<&mut Self>, item: Frame<Bytes>) -> Result<(), Error> {
183 let this = self.project();
184
185 assert!(this.buf.is_empty());
186 this.sender
187 .start_send(item)
188 .map_err(|e| Error::service_safe(e, ClientIo))
189 }
190
191 fn poll_flush_shallow(
192 self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 ) -> Poll<Result<(), mpsc::SendError>> {
195 let mut this = self.project();
196
197 if this.buf.is_empty() {
198 return Poll::Ready(Ok(()));
199 }
200
201 ready!(this.sender.as_mut().poll_ready(cx))?;
202 this.sender
203 .start_send(Frame::data(this.buf.split().freeze()))?;
204
205 Poll::Ready(Ok(()))
206 }
207}
208
209impl Sink<Bytes> for ResponseWriter {
210 type Error = Error;
211
212 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213 ready!(self.as_mut().poll_flush_shallow(cx))
214 .map_err(|e| Error::service_safe(e, ClientIo))?;
215
216 self.project()
217 .sender
218 .poll_ready(cx)
219 .map_err(|e| Error::service_safe(e, ClientIo))
220 }
221
222 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
223 self.start_send_inner(Frame::data(item))
224 }
225
226 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227 ready!(self.as_mut().poll_flush_shallow(cx))
228 .map_err(|e| Error::service_safe(e, ClientIo))?;
229
230 self.project()
231 .sender
232 .poll_flush(cx)
233 .map_err(|e| Error::service_safe(e, ClientIo))
234 }
235
236 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
237 ready!(self.as_mut().poll_flush_shallow(cx))
238 .map_err(|e| Error::service_safe(e, ClientIo))?;
239
240 self.project()
241 .sender
242 .poll_close(cx)
243 .map_err(|e| Error::service_safe(e, ClientIo))
244 }
245}
246
247impl AsyncWrite for ResponseWriter {
248 fn poll_write(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 buf: &[u8],
252 ) -> Poll<io::Result<usize>> {
253 if self.buf.len() > 4096 {
254 ready!(self.as_mut().poll_flush_shallow(cx))
255 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
256 }
257
258 self.project().buf.extend_from_slice(buf);
259 Poll::Ready(Ok(buf.len()))
260 }
261
262 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
263 ready!(self.as_mut().poll_flush_shallow(cx))
264 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
265
266 self.project()
267 .sender
268 .poll_flush(cx)
269 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
270 }
271
272 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
273 ready!(self.as_mut().poll_flush_shallow(cx))
274 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
275
276 self.project()
277 .sender
278 .poll_close(cx)
279 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
280 }
281}
282
/// Error type attached to failures that occur while reading the request body or writing the
/// response body.
pub(crate) struct ClientIo;
284
285impl Serialize for ClientIo {
286 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
287 where
288 S: Serializer,
289 {
290 serializer.serialize_struct("ClientIo", 0)?.end()
291 }
292}
293
294impl ErrorType for ClientIo {
295 fn code(&self) -> ErrorCode {
296 ErrorCode::CustomClient
297 }
298
299 fn name(&self) -> &str {
300 "Witchcraft:ClientIo"
301 }
302
303 fn instance_id(&self) -> Option<Uuid> {
304 None
305 }
306
307 fn safe_args(&self) -> &'static [&'static str] {
308 &[]
309 }
310}
311
#[cfg(test)]
mod test {
    use super::*;

    // Smoke test: building a service error from `ClientIo` must compile and not panic.
    #[test]
    fn conjure_error_from_client_io() {
        let _error = Error::service_safe("", ClientIo);
    }
}