Skip to main content

xitca_http/h1/
dispatcher.rs

1use core::{
2    future::poll_fn,
3    marker::PhantomData,
4    net::SocketAddr,
5    pin::{Pin, pin},
6    task::Poll,
7    time::Duration,
8};
9
10use std::net::Shutdown;
11
12use tracing::trace;
13use xitca_io::io::{AsyncBufRead, AsyncBufWrite};
14use xitca_service::Service;
15use xitca_unsafe_collection::futures::SelectOutput;
16
17use crate::{
18    body::{Body, Frame},
19    bytes::{Bytes, BytesMut},
20    config::HttpServiceConfig,
21    date::DateTime,
22    h1::error::Error,
23    http::{StatusCode, response::Response},
24    util::timer::{KeepAlive, Timeout},
25};
26
27use super::{
28    body::{RequestBody, body},
29    io::{BufIo, SharedIo},
30    proto::{buf_write::H1BufWrite, context::Context, error::ProtoError},
31};
32
/// Request type passed to the service: the body slot of the `http` request holds a
/// `RequestExt<B>`, which wraps the actual request body `B` (see `map_body` in `_run`).
type ExtRequest<B> = crate::http::Request<crate::http::RequestExt<B>>;
34
/// Http/1 dispatcher
///
/// Serves the requests of a single connection: request heads are decoded from `io`, handed to
/// `service`, and the resulting responses are encoded and streamed back. Constructed and driven
/// by [`Dispatcher::run`].
///
/// `W_LIMIT` is the amount of buffered response data at which output is flushed while a body
/// is streaming. `H_LIMIT` and `R_LIMIT` are forwarded to the protocol context / decoder
/// (presumably header-count and read-buffer limits — confirm against `HttpServiceConfig`).
pub struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> {
    // NOTE: field order is also drop order; keep it stable.
    /// Connection io. A notifier derived from it is handed to streaming request bodies.
    io: SharedIo<Io>,
    /// Keep-alive / request-head timeout state machine around the caller's timer.
    timer: Timer<'a>,
    /// Protocol context: head decoding/encoding and connection close state.
    ctx: Context<'a, D, H_LIMIT>,
    /// Service invoked once per decoded request.
    service: &'a S,
    /// Ties the request body type `ReqB` produced for `service` to the dispatcher.
    _phantom: PhantomData<ReqB>,
}
43
impl<'a, Io, S, ReqB, ResB, BE, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize>
    Dispatcher<'a, Io, S, ReqB, D, H_LIMIT, R_LIMIT, W_LIMIT>
where
    Io: AsyncBufRead + AsyncBufWrite + 'static,
    S: Service<ExtRequest<ReqB>, Response = Response<ResB>>,
    ReqB: From<RequestBody>,
    ResB: Body<Data = Bytes, Error = BE>,
    D: DateTime,
{
    /// Drive one HTTP/1 connection until it is closed or fails.
    ///
    /// Bytes already present in `read_buf` are decoded before anything else is read from `io`.
    /// `timer` is re-armed before every read using the keep-alive and request-head timeouts
    /// from `config`.
    ///
    /// Each iteration of the main loop:
    /// 1. serves every request whose head is fully buffered (`_run`),
    /// 2. flushes `write_buf` to `io`,
    /// 3. stops if the connection has been marked closed,
    /// 4. arms the timer and reads more bytes. A 0-byte read ends the loop cleanly; a timeout
    ///    becomes a keep-alive-expired or request-timeout error.
    ///
    /// If the loop ends with an error, `handle_error` either queues an error response into
    /// `write_buf` or returns the error directly (skipping shutdown). Otherwise the remaining
    /// output is flushed and the socket is shut down in both directions.
    ///
    /// # Errors
    ///
    /// Returns errors that `handle_error` does not turn into a response (e.g. service, body or
    /// I/O errors), and errors from the final flush/shutdown. I/O errors hit by the `?`s inside
    /// the loop return immediately, without the final flush-and-shutdown step.
    pub async fn run(
        io: Io,
        addr: SocketAddr,
        read_buf: BytesMut,
        timer: Pin<&'a mut KeepAlive>,
        config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
        service: &'a S,
        date: &'a D,
    ) -> Result<(), Error<S::Error, BE>> {
        let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> {
            io: SharedIo::new(io),
            timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
            ctx: Context::with_addr(addr, date),
            service,
            _phantom: PhantomData,
        };

        // Buffers are moved into every I/O call and handed back with its result, so they are
        // rebound after each call.
        let mut read_buf = read_buf;
        let mut write_buf = BytesMut::new();

        let res = loop {
            // serve all requests that are already fully buffered.
            let (res, r_buf, w_buf) = dispatcher._run(read_buf, write_buf).await;
            read_buf = r_buf;
            write_buf = w_buf;

            if let Err(err) = res {
                break Err(err);
            }

            // flush responses produced by `_run` before waiting for more input.
            let (res, w_buf) = write_buf.write(dispatcher.io.io()).await;
            write_buf = w_buf;

            res?;

            if dispatcher.ctx.is_connection_closed() {
                break Ok(());
            }

            // arm keep-alive (first wait) or request-head (subsequent wait) deadline.
            dispatcher.timer.update(dispatcher.ctx.date().now());

            match read_buf.read(dispatcher.io.io()).timeout(dispatcher.timer.get()).await {
                Ok((res, r_buf)) => {
                    read_buf = r_buf;

                    // 0 bytes read: peer closed its side; finish gracefully.
                    if res? == 0 {
                        break Ok(());
                    }
                }
                // the timer's current state decides which timeout error this is.
                Err(_) => break Err(dispatcher.timer.map_to_err()),
            }
        };

        if let Err(err) = res {
            // may append an error response to `write_buf`, or give the error back via `?`.
            handle_error(&mut dispatcher.ctx, &mut write_buf, err)?;
        }

        dispatcher.shutdown(write_buf).await
    }

    /// Serve every request whose head is fully buffered in `read_buf`.
    ///
    /// Returns when no further request can be decoded from `read_buf` (the caller has to read
    /// more), or when a request body does not hand a read buffer back (the connection is then
    /// marked closed). Both buffers are always returned to the caller alongside the result,
    /// including on error.
    async fn _run(
        &mut self,
        mut read_buf: BytesMut,
        mut write_buf: BytesMut,
    ) -> (Result<(), Error<S::Error, BE>>, BytesMut, BytesMut) {
        loop {
            // `Ok(None)`: no request could be decoded from what is buffered; let the caller read.
            let (req, decoder) = match self.ctx.decode_head::<R_LIMIT>(&mut read_buf) {
                Ok(Some(req)) => req,
                Ok(None) => break,
                Err(e) => return (Err(e.into()), read_buf, write_buf),
            };

            // a head was decoded: the next timer update starts from keep-alive again.
            self.timer.reset_state();

            // Requests without a body get an empty one. Otherwise the body reader takes the rest
            // of `read_buf` and the read buffer must be handed back later (`self.io.wait()`).
            let (wait_for_notify, body) = if decoder.is_eof() {
                (false, RequestBody::default())
            } else {
                let body = body(
                    self.io.notifier(),
                    self.ctx.is_expect_header(),
                    R_LIMIT,
                    decoder,
                    read_buf.split(),
                );

                (true, body)
            };

            let req = req.map(|ext| ext.map_body(|_| ReqB::from(body)));

            let (parts, body) = match self.service.call(req).await {
                Ok(res) => res.into_parts(),
                Err(e) => return (Err(Error::Service(e)), read_buf, write_buf),
            };

            let mut encoder = match self.ctx.encode_head(parts, &body, &mut write_buf) {
                Ok(encoder) => encoder,
                Err(e) => return (Err(e.into()), read_buf, write_buf),
            };

            // This block scope is necessary: ResB has to be dropped as soon as possible because
            // it may hold ownership of a Body type which, if not dropped before Notifier::notify
            // is called, would prevent Notifier from waking up Notify.
            {
                let mut body = pin!(body);

                let trailers = loop {
                    // Poll the response body, but instead of idling on a pending body, go flush
                    // whatever is already buffered (`SelectOutput::B`).
                    let res = poll_fn(|cx| match body.as_mut().poll_frame(cx) {
                        Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)),
                        Poll::Pending if write_buf.is_empty() => Poll::Pending,
                        Poll::Pending => Poll::Ready(SelectOutput::B(())),
                    })
                    .await;

                    let res = match res {
                        SelectOutput::A(Some(Ok(Frame::Data(bytes)))) => {
                            encoder.encode(bytes, &mut write_buf);
                            // batch data frames; only flush once W_LIMIT is reached.
                            if write_buf.len() < W_LIMIT {
                                continue;
                            }
                            Ok(())
                        }
                        SelectOutput::A(Some(Ok(Frame::Trailers(trailers)))) => break Some(trailers),
                        // still flush what is buffered before reporting the body error.
                        SelectOutput::A(Some(Err(e))) => Err(Error::Body(e)),
                        SelectOutput::A(None) => break None,
                        SelectOutput::B(_) => Ok(()),
                    };

                    let (res_io, w_buf) = write_buf.write(self.io.io()).await;
                    write_buf = w_buf;

                    // a body error takes precedence over an I/O error from this flush.
                    if let Some(e) = res.err().or_else(|| res_io.err().map(Error::Io)) {
                        return (Err(e), read_buf, write_buf);
                    }
                };

                encoder.encode_eof(trailers, &mut write_buf);
            }

            // Wait for the request body to return the read buffer it took above. `None` means
            // none came back, so the connection is closed after this response.
            if wait_for_notify {
                match self.io.wait().await {
                    Some(r_buf) => read_buf = r_buf,
                    None => {
                        self.ctx.set_close();
                        break;
                    }
                }
            }
        }

        (Ok(()), read_buf, write_buf)
    }

    /// Consume the dispatcher: take back the io, flush what is left in `write_buf`, then shut
    /// the socket down in both directions.
    #[cold]
    #[inline(never)]
    async fn shutdown(self, write_buf: BytesMut) -> Result<(), Error<S::Error, BE>> {
        let io = self.io.into_io();
        let (res, _) = write_buf.write(&io).await;
        res?;
        io.shutdown(Shutdown::Both).await.map_err(Into::into)
    }
}
214
// Timer state transitions, driven by `Timer::update` and `Timer::reset_state`:
//
// Idle ---update (arms keep-alive duration)---> Wait
//  ^                                              |
//  |                          update (arms request head duration)
//  |                                              v
//  +-------------reset_state-------------- Throttle
//
// `reset_state` returns to Idle from any state; `update` while in Throttle is a no-op.
enum TimerState {
    /// Freshly reset; the next `update` arms the keep-alive duration.
    Idle,
    /// Keep-alive duration armed; expiring here maps to `Error::KeepAliveExpire`.
    Wait,
    /// Request head duration armed; expiring here maps to `Error::RequestTimeout`.
    Throttle,
}
227
/// Two-phase connection timer wrapping a caller-owned [`KeepAlive`].
///
/// Driven through `TimerState`: the keep-alive duration is armed first, and the request head
/// duration is armed by the following update unless the state was reset in between.
struct Timer<'a> {
    /// Pinned timer borrowed from the caller of `Dispatcher::run`.
    timer: Pin<&'a mut KeepAlive>,
    /// Current phase of the state machine.
    state: TimerState,
    /// Duration armed on the `Idle -> Wait` transition.
    ka_dur: Duration,
    /// Duration armed on the `Wait -> Throttle` transition.
    req_dur: Duration,
}
234
235impl<'a> Timer<'a> {
236    fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self {
237        Self {
238            timer,
239            state: TimerState::Idle,
240            ka_dur,
241            req_dur,
242        }
243    }
244
245    fn reset_state(&mut self) {
246        self.state = TimerState::Idle;
247    }
248
249    fn get(&mut self) -> Pin<&mut KeepAlive> {
250        self.timer.as_mut()
251    }
252
253    // update timer with a given base instant value. the final deadline is calculated base on it.
254    fn update(&mut self, now: tokio::time::Instant) {
255        let dur = match self.state {
256            TimerState::Idle => {
257                self.state = TimerState::Wait;
258                self.ka_dur
259            }
260            TimerState::Wait => {
261                self.state = TimerState::Throttle;
262                self.req_dur
263            }
264            TimerState::Throttle => return,
265        };
266        self.timer.as_mut().update(now + dur)
267    }
268
269    #[cold]
270    #[inline(never)]
271    fn map_to_err<SE, BE>(&self) -> Error<SE, BE> {
272        match self.state {
273            TimerState::Wait => Error::KeepAliveExpire,
274            TimerState::Throttle => Error::RequestTimeout,
275            TimerState::Idle => unreachable!(),
276        }
277    }
278}
279
280#[cold]
281#[inline(never)]
282fn handle_error<D, W, S, B, const H_LIMIT: usize>(
283    ctx: &mut Context<'_, D, H_LIMIT>,
284    buf: &mut W,
285    err: Error<S, B>,
286) -> Result<(), Error<S, B>>
287where
288    D: DateTime,
289    W: H1BufWrite,
290{
291    ctx.set_close();
292    match err {
293        Error::KeepAliveExpire => {
294            trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down")
295        }
296        e => {
297            let status = match e {
298                Error::RequestTimeout => StatusCode::REQUEST_TIMEOUT,
299                Error::Proto(ProtoError::HeaderTooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
300                Error::Proto(_) => StatusCode::BAD_REQUEST,
301                e => return Err(e),
302            };
303            let (parts, body) = Response::builder()
304                .status(status)
305                .body(crate::body::Empty::<Bytes>::new())
306                .unwrap()
307                .into_parts();
308            ctx.encode_head(parts, &body, buf)
309                .expect("request_error must be correct");
310        }
311    }
312    Ok(())
313}