1use core::{
2 future::poll_fn,
3 marker::PhantomData,
4 net::SocketAddr,
5 pin::{Pin, pin},
6 task::Poll,
7 time::Duration,
8};
9
10use std::net::Shutdown;
11
12use tracing::trace;
13use xitca_io::io::{AsyncBufRead, AsyncBufWrite};
14use xitca_service::Service;
15use xitca_unsafe_collection::futures::SelectOutput;
16
17use crate::{
18 body::{Body, Frame},
19 bytes::{Bytes, BytesMut},
20 config::HttpServiceConfig,
21 date::DateTime,
22 h1::error::Error,
23 http::{StatusCode, response::Response},
24 util::timer::{KeepAlive, Timeout},
25};
26
27use super::{
28 body::{RequestBody, body},
29 io::{BufIo, SharedIo},
30 proto::{buf_write::H1BufWrite, context::Context, error::ProtoError},
31};
32
33type ExtRequest<B> = crate::http::Request<crate::http::RequestExt<B>>;
34
/// Per-connection state for serving HTTP/1 requests on a single `Io` stream.
///
/// The const parameters are forwarded from `HttpServiceConfig`: `H_LIMIT` sizes the
/// header context, `R_LIMIT` bounds request-head decoding and the request body reader,
/// and `W_LIMIT` is the write-buffer size at which response body data is flushed.
pub struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> {
    // Connection I/O. Its notifier is handed to streaming request bodies, and
    // `wait()` later returns the read buffer from them.
    io: SharedIo<Io>,
    // Keep-alive / request-head timeout tracking for reads between requests.
    timer: Timer<'a>,
    // HTTP/1 protocol context: head decoding/encoding, peer address, date source,
    // and the connection-close flag.
    ctx: Context<'a, D, H_LIMIT>,
    // User service invoked once per decoded request.
    service: &'a S,
    // Ties the request body type `ReqB` to the dispatcher without storing one.
    _phantom: PhantomData<ReqB>,
}
43
impl<'a, Io, S, ReqB, ResB, BE, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize>
    Dispatcher<'a, Io, S, ReqB, D, H_LIMIT, R_LIMIT, W_LIMIT>
where
    Io: AsyncBufRead + AsyncBufWrite + 'static,
    S: Service<ExtRequest<ReqB>, Response = Response<ResB>>,
    ReqB: From<RequestBody>,
    ResB: Body<Data = Bytes, Error = BE>,
    D: DateTime,
{
    /// Drive one HTTP/1 connection until it closes, times out, or fails.
    ///
    /// Each iteration of the main loop:
    /// 1. serves every request whose head is already in `read_buf` (see `_run`);
    /// 2. flushes `write_buf` to the connection;
    /// 3. stops if the protocol context marked the connection as closed;
    /// 4. arms the keep-alive / request-head timer and reads more bytes, racing
    ///    the read against that timer.
    ///
    /// A read of zero bytes ends the loop successfully. An error from `_run`, or a
    /// timer expiry, is passed to `handle_error`, which may queue an error response.
    /// In both cases the connection is then flushed and shut down via `shutdown`.
    ///
    /// # Errors
    ///
    /// - I/O errors from the flush after `_run`, or from the read, are returned
    ///   immediately without running `shutdown`.
    /// - Errors that `handle_error` does not turn into a response are returned as-is.
    /// - Errors from the final flush or socket shutdown are propagated.
    pub async fn run(
        io: Io,
        addr: SocketAddr,
        read_buf: BytesMut,
        timer: Pin<&'a mut KeepAlive>,
        config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
        service: &'a S,
        date: &'a D,
    ) -> Result<(), Error<S::Error, BE>> {
        let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> {
            io: SharedIo::new(io),
            timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
            ctx: Context::with_addr(addr, date),
            service,
            _phantom: PhantomData,
        };

        // Buffers are moved into every I/O call and handed back alongside the result,
        // so they are re-bound after each call.
        let mut read_buf = read_buf;
        let mut write_buf = BytesMut::new();

        let res = loop {
            let (res, r_buf, w_buf) = dispatcher._run(read_buf, write_buf).await;
            read_buf = r_buf;
            write_buf = w_buf;

            if let Err(err) = res {
                break Err(err);
            }

            // Flush whatever `_run` left in the write buffer before waiting for input.
            let (res, w_buf) = write_buf.write(dispatcher.io.io()).await;
            write_buf = w_buf;

            res?;

            if dispatcher.ctx.is_connection_closed() {
                break Ok(());
            }

            // The first wait after a request arms the keep-alive deadline, and the
            // next arms the request-head deadline. After that the deadline is not
            // re-armed until a new head is decoded (see `Timer::update`).
            dispatcher.timer.update(dispatcher.ctx.date().now());

            match read_buf.read(dispatcher.io.io()).timeout(dispatcher.timer.get()).await {
                Ok((res, r_buf)) => {
                    read_buf = r_buf;

                    // Zero bytes read: the peer closed its side.
                    if res? == 0 {
                        break Ok(());
                    }
                }
                // The timer fired first. The error variant depends on which
                // deadline was armed.
                Err(_) => break Err(dispatcher.timer.map_to_err()),
            }
        };

        if let Err(err) = res {
            handle_error(&mut dispatcher.ctx, &mut write_buf, err)?;
        }

        dispatcher.shutdown(write_buf).await
    }

    /// Serve every request whose head can be decoded from `read_buf`.
    ///
    /// For each decoded head this method:
    /// - resets the timer state;
    /// - builds the request body (empty when the decoder is already at EOF, otherwise
    ///   a streaming body that takes the rest of `read_buf`);
    /// - calls the service and encodes the response head into `write_buf`;
    /// - streams the response body, flushing whenever `write_buf` reaches `W_LIMIT`
    ///   or the body is not ready but output is pending.
    ///
    /// The loop ends when `decode_head` yields `Ok(None)` (presumably an incomplete or
    /// empty buffer), or when the request body does not hand the read buffer back.
    ///
    /// Both buffers are always returned, even on error, so the caller keeps ownership.
    async fn _run(
        &mut self,
        mut read_buf: BytesMut,
        mut write_buf: BytesMut,
    ) -> (Result<(), Error<S::Error, BE>>, BytesMut, BytesMut) {
        loop {
            let (req, decoder) = match self.ctx.decode_head::<R_LIMIT>(&mut read_buf) {
                Ok(Some(req)) => req,
                Ok(None) => break,
                Err(e) => return (Err(e.into()), read_buf, write_buf),
            };

            // A new request head arrived, so the next wait starts from keep-alive again.
            self.timer.reset_state();

            // `wait_for_notify` records that the request body took ownership of the
            // buffered bytes (via `split`). The read buffer must be reclaimed from it
            // with `io.wait()` before the next head can be decoded.
            let (wait_for_notify, body) = if decoder.is_eof() {
                (false, RequestBody::default())
            } else {
                let body = body(
                    self.io.notifier(),
                    self.ctx.is_expect_header(),
                    R_LIMIT,
                    decoder,
                    read_buf.split(),
                );

                (true, body)
            };

            let req = req.map(|ext| ext.map_body(|_| ReqB::from(body)));

            let (parts, body) = match self.service.call(req).await {
                Ok(res) => res.into_parts(),
                Err(e) => return (Err(Error::Service(e)), read_buf, write_buf),
            };

            let mut encoder = match self.ctx.encode_head(parts, &body, &mut write_buf) {
                Ok(encoder) => encoder,
                Err(e) => return (Err(e.into()), read_buf, write_buf),
            };

            // The pinned response body lives only in this scope. It is dropped before
            // waiting on the request body below.
            {
                let mut body = pin!(body);

                let trailers = loop {
                    // Poll the next body frame. If the body is not ready but bytes are
                    // queued, yield `B` so they get flushed instead of sitting idle.
                    let res = poll_fn(|cx| match body.as_mut().poll_frame(cx) {
                        Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)),
                        Poll::Pending if write_buf.is_empty() => Poll::Pending,
                        Poll::Pending => Poll::Ready(SelectOutput::B(())),
                    })
                    .await;

                    let res = match res {
                        SelectOutput::A(Some(Ok(Frame::Data(bytes)))) => {
                            encoder.encode(bytes, &mut write_buf);
                            // Keep batching data frames until the buffer reaches W_LIMIT.
                            if write_buf.len() < W_LIMIT {
                                continue;
                            }
                            Ok(())
                        }
                        SelectOutput::A(Some(Ok(Frame::Trailers(trailers)))) => break Some(trailers),
                        SelectOutput::A(Some(Err(e))) => Err(Error::Body(e)),
                        SelectOutput::A(None) => break None,
                        SelectOutput::B(_) => Ok(()),
                    };

                    // Flush even when the body errored, so already-encoded bytes are sent.
                    let (res_io, w_buf) = write_buf.write(self.io.io()).await;
                    write_buf = w_buf;

                    // A body error takes precedence over an I/O error from the flush.
                    if let Some(e) = res.err().or_else(|| res_io.err().map(Error::Io)) {
                        return (Err(e), read_buf, write_buf);
                    }
                };

                encoder.encode_eof(trailers, &mut write_buf);
            }

            if wait_for_notify {
                match self.io.wait().await {
                    Some(r_buf) => read_buf = r_buf,
                    // The read buffer was not handed back (presumably the request body
                    // was not fully consumed), so the connection cannot be reused.
                    None => {
                        self.ctx.set_close();
                        break;
                    }
                }
            }
        }

        (Ok(()), read_buf, write_buf)
    }

    /// Flush any remaining bytes in `write_buf`, then shut down both directions of
    /// the underlying connection.
    ///
    /// # Errors
    ///
    /// Returns the flush error if writing fails. Otherwise returns the result of
    /// `shutdown(Shutdown::Both)`, converted into the dispatcher error type.
    #[cold]
    #[inline(never)]
    async fn shutdown(self, write_buf: BytesMut) -> Result<(), Error<S::Error, BE>> {
        let io = self.io.into_io();
        let (res, _) = write_buf.write(&io).await;
        res?;
        io.shutdown(Shutdown::Both).await.map_err(Into::into)
    }
}
214
/// Which deadline, if any, the connection timer is currently enforcing.
///
/// Transitions are driven by `Timer::update` (Idle -> Wait -> Throttle) and
/// `Timer::reset_state` (back to Idle once a request head is decoded).
enum TimerState {
    // No deadline armed since the last request head. The next `update` arms keep-alive.
    Idle,
    // Keep-alive deadline armed. Expiry maps to `Error::KeepAliveExpire`.
    Wait,
    // Request-head deadline armed and no longer re-armed by `update`.
    // Expiry maps to `Error::RequestTimeout`.
    Throttle,
}
227
/// A borrowed, pinned `KeepAlive` timer plus the state machine that decides
/// which deadline to arm it with.
struct Timer<'a> {
    // The underlying timer future, re-armed through `KeepAlive::update`.
    timer: Pin<&'a mut KeepAlive>,
    // Current deadline state (see `TimerState`).
    state: TimerState,
    // Keep-alive duration (from `config.keep_alive_timeout` in `Dispatcher::run`).
    ka_dur: Duration,
    // Request-head duration (from `config.request_head_timeout` in `Dispatcher::run`).
    req_dur: Duration,
}
234
235impl<'a> Timer<'a> {
236 fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self {
237 Self {
238 timer,
239 state: TimerState::Idle,
240 ka_dur,
241 req_dur,
242 }
243 }
244
245 fn reset_state(&mut self) {
246 self.state = TimerState::Idle;
247 }
248
249 fn get(&mut self) -> Pin<&mut KeepAlive> {
250 self.timer.as_mut()
251 }
252
253 fn update(&mut self, now: tokio::time::Instant) {
255 let dur = match self.state {
256 TimerState::Idle => {
257 self.state = TimerState::Wait;
258 self.ka_dur
259 }
260 TimerState::Wait => {
261 self.state = TimerState::Throttle;
262 self.req_dur
263 }
264 TimerState::Throttle => return,
265 };
266 self.timer.as_mut().update(now + dur)
267 }
268
269 #[cold]
270 #[inline(never)]
271 fn map_to_err<SE, BE>(&self) -> Error<SE, BE> {
272 match self.state {
273 TimerState::Wait => Error::KeepAliveExpire,
274 TimerState::Throttle => Error::RequestTimeout,
275 TimerState::Idle => unreachable!(),
276 }
277 }
278}
279
280#[cold]
281#[inline(never)]
282fn handle_error<D, W, S, B, const H_LIMIT: usize>(
283 ctx: &mut Context<'_, D, H_LIMIT>,
284 buf: &mut W,
285 err: Error<S, B>,
286) -> Result<(), Error<S, B>>
287where
288 D: DateTime,
289 W: H1BufWrite,
290{
291 ctx.set_close();
292 match err {
293 Error::KeepAliveExpire => {
294 trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down")
295 }
296 e => {
297 let status = match e {
298 Error::RequestTimeout => StatusCode::REQUEST_TIMEOUT,
299 Error::Proto(ProtoError::HeaderTooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
300 Error::Proto(_) => StatusCode::BAD_REQUEST,
301 e => return Err(e),
302 };
303 let (parts, body) = Response::builder()
304 .status(status)
305 .body(crate::body::Empty::<Bytes>::new())
306 .unwrap()
307 .into_parts();
308 ctx.encode_head(parts, &body, buf)
309 .expect("request_error must be correct");
310 }
311 }
312 Ok(())
313}