witchcraft_server/blocking/
body.rs1use crate::body::ClientIo;
15use crate::server::RawBody;
16use bytes::{Buf, Bytes, BytesMut};
17use conjure_error::Error;
18use futures_channel::mpsc;
19use futures_util::{Future, SinkExt};
20use http::HeaderMap;
21use http_body::Frame;
22use http_body_util::BodyExt;
23use std::io::{BufRead, Read, Write};
24use std::time::Duration;
25use std::{error, io, mem};
26use tokio::runtime::Handle;
27use tokio::time;
28
/// Upper bound on how long each individual blocking operation against the client (reading
/// one request body frame, or pushing one response body part) may take before it fails.
const IO_TIMEOUT: Duration = Duration::new(60, 0);
30
/// A blocking view of a streaming request body.
///
/// Body data can be consumed either as an iterator of `Bytes` chunks or through the
/// `Read`/`BufRead` traits. Trailers encountered while reading are retained and can be
/// retrieved with `trailers`.
pub struct RequestBody {
    /// The underlying asynchronous body that frames are pulled from.
    inner: RawBody,
    /// Runtime handle used to block the calling thread on async body reads.
    handle: Handle,
    /// Data from the most recently received frame that has not been handed out yet.
    cur: Bytes,
    /// Trailers received so far, merged across trailer frames.
    trailers: Option<HeaderMap>,
}
38
39impl RequestBody {
40 pub(crate) fn new(inner: RawBody, handle: Handle) -> Self {
41 RequestBody {
42 inner,
43 handle,
44 cur: Bytes::new(),
45 trailers: None,
46 }
47 }
48
49 pub fn trailers(&mut self) -> Option<HeaderMap> {
53 self.trailers.take()
54 }
55
56 fn next_raw(&mut self) -> Result<Option<Bytes>, Box<dyn error::Error + Sync + Send>> {
57 loop {
58 let next = self
59 .handle
60 .block_on(async { time::timeout(IO_TIMEOUT, self.inner.frame()).await })?
61 .transpose()?;
62
63 let Some(next) = next else {
64 return Ok(None);
65 };
66
67 let next = match next.into_data() {
68 Ok(data) => return Ok(Some(data)),
69 Err(next) => next,
70 };
71
72 if let Ok(trailers) = next.into_trailers() {
73 match &mut self.trailers {
74 Some(base) => base.extend(trailers),
75 None => self.trailers = Some(trailers),
76 }
77 }
78 }
79 }
80}
81
82impl Iterator for RequestBody {
83 type Item = Result<Bytes, Error>;
84
85 fn next(&mut self) -> Option<Self::Item> {
86 if self.cur.has_remaining() {
87 return Some(Ok(mem::take(&mut self.cur)));
88 }
89
90 self.next_raw()
91 .map_err(|e| Error::service_safe(e, ClientIo))
92 .transpose()
93 }
94}
95
96impl Read for RequestBody {
97 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
98 let in_buf = self.fill_buf()?;
99 let len = usize::min(in_buf.len(), buf.len());
100 buf[..len].copy_from_slice(&in_buf[..len]);
101 self.consume(len);
102 Ok(len)
103 }
104}
105
106impl BufRead for RequestBody {
107 fn fill_buf(&mut self) -> io::Result<&[u8]> {
108 while self.cur.is_empty() {
109 match self.next_raw().map_err(io::Error::other)? {
110 Some(bytes) => self.cur = bytes,
111 None => break,
112 }
113 }
114
115 Ok(&self.cur)
116 }
117
118 fn consume(&mut self, amt: usize) {
119 self.cur.advance(amt)
120 }
121}
122
123pub enum BodyPart {
124 Frame(Frame<Bytes>),
125 Done,
126}
127
/// A blocking writer for a streaming response body.
///
/// Body parts are pushed onto an `mpsc` channel. Bytes written through `io::Write` are
/// accumulated in `buf` and sent as data frames when flushed.
pub struct ResponseWriter {
    /// Channel over which body parts are delivered.
    sender: mpsc::Sender<BodyPart>,
    /// Runtime handle used to block the calling thread on channel operations.
    handle: Handle,
    /// Bytes written via `io::Write` that have not yet been sent.
    buf: BytesMut,
}
134
135impl ResponseWriter {
136 pub(crate) fn new(sender: mpsc::Sender<BodyPart>, handle: Handle) -> Self {
137 Self {
138 sender,
139 handle,
140 buf: BytesMut::new(),
141 }
142 }
143
144 pub fn send(&mut self, bytes: Bytes) -> Result<(), Error> {
149 self.send_inner(BodyPart::Frame(Frame::data(bytes)))
150 }
151
152 pub fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
156 self.send_inner(BodyPart::Frame(Frame::trailers(trailers)))
157 }
158
159 fn send_inner(&mut self, part: BodyPart) -> Result<(), Error> {
160 self.flush_shallow()
161 .map_err(|e| Error::service_safe(e, ClientIo))?;
162
163 Self::with_timeout(&self.handle, self.sender.feed(part))
164 .map_err(|e| Error::service_safe(e, ClientIo))?;
165
166 Ok(())
167 }
168
169 fn with_timeout<F, R, E>(
170 handle: &Handle,
171 future: F,
172 ) -> Result<R, Box<dyn error::Error + Sync + Send>>
173 where
174 F: Future<Output = Result<R, E>>,
175 E: Into<Box<dyn error::Error + Sync + Send>>,
176 {
177 handle
178 .block_on(async { time::timeout(IO_TIMEOUT, future).await })?
179 .map_err(Into::into)
180 }
181
182 fn flush_shallow(&mut self) -> Result<(), Box<dyn error::Error + Sync + Send>> {
183 if self.buf.is_empty() {
184 return Ok(());
185 }
186
187 Self::with_timeout(
188 &self.handle,
189 self.sender
190 .feed(BodyPart::Frame(Frame::data(self.buf.split().freeze()))),
191 )
192 }
193
194 pub(crate) fn finish(mut self) -> Result<(), Error> {
195 self.flush_shallow()
196 .map_err(|e| Error::service_safe(e, ClientIo))?;
197
198 Self::with_timeout(&self.handle, self.sender.send(BodyPart::Done))
199 .map_err(|e| Error::service_safe(e, ClientIo))?;
200
201 Ok(())
202 }
203}
204
205impl Write for ResponseWriter {
206 #[inline]
207 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
208 if self.buf.len() > 4096 {
209 self.flush_shallow().map_err(io::Error::other)?;
210 }
211
212 self.buf.extend_from_slice(buf);
213 Ok(buf.len())
214 }
215
216 fn flush(&mut self) -> io::Result<()> {
217 self.flush_shallow().map_err(io::Error::other)?;
218
219 Self::with_timeout(&self.handle, self.sender.flush()).map_err(io::Error::other)?;
220
221 Ok(())
222 }
223}