ttpkit_http/client/
mod.rs

1//! HTTP client.
2
3mod connector;
4mod receiver;
5
6pub mod request;
7pub mod response;
8
9use std::{
10    future::Future,
11    io,
12    pin::Pin,
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use bytes::BytesMut;
18use futures::{FutureExt, StreamExt, ready};
19use tokio::io::AsyncWriteExt;
20
21use crate::{
22    Error, Version,
23    body::{Body, ChunkedStream},
24    connection::{Connection as HttpConnection, ConnectionReader, ConnectionWriter},
25    request::{RequestHeader, RequestHeaderEncoder},
26    url::Url,
27};
28
29use self::receiver::{ConnectionReaderJoinHandle, ResponseDecoder, ResponseDecoderOptions};
30
31pub use self::{
32    connector::{Connection, Connector},
33    request::OutgoingRequest,
34    response::IncomingResponse,
35};
36
37// TODO: add connection pooling (note: it should be a part of the client and
38// not the connector)
39// TODO: add request pipelining
40
41/// Builder for the HTTP client.
42pub struct ClientBuilder {
43    connection_timeout: Option<Duration>,
44    read_timeout: Option<Duration>,
45    write_timeout: Option<Duration>,
46    request_timeout: Option<Duration>,
47    decoder_options: ResponseDecoderOptions,
48}
49
50impl ClientBuilder {
51    /// Create a new builder.
52    #[inline]
53    const fn new() -> Self {
54        Self {
55            connection_timeout: Some(Duration::from_secs(60)),
56            read_timeout: Some(Duration::from_secs(60)),
57            write_timeout: Some(Duration::from_secs(60)),
58            request_timeout: Some(Duration::from_secs(60)),
59            decoder_options: ResponseDecoderOptions::new(),
60        }
61    }
62
63    /// Set connection timeout (default is 60 seconds).
64    #[inline]
65    pub const fn connection_timeout(mut self, timeout: Option<Duration>) -> Self {
66        self.connection_timeout = timeout;
67        self
68    }
69
70    /// Set read timeout (default is 60 seconds).
71    #[inline]
72    pub const fn read_timeout(mut self, timeout: Option<Duration>) -> Self {
73        self.read_timeout = timeout;
74        self
75    }
76
77    /// Set write timeout (default is 60 seconds).
78    #[inline]
79    pub const fn write_timeout(mut self, timeout: Option<Duration>) -> Self {
80        self.write_timeout = timeout;
81        self
82    }
83
84    /// Set request timeout (default is 60 seconds).
85    ///
86    /// Note: The request timeout does not include reading the body.
87    #[inline]
88    pub const fn request_timeout(mut self, timeout: Option<Duration>) -> Self {
89        self.request_timeout = timeout;
90        self
91    }
92
93    /// Set maximum line length for response header lines and chunked body
94    /// headers.
95    #[inline]
96    pub const fn max_line_length(mut self, max_length: Option<usize>) -> Self {
97        self.decoder_options = self.decoder_options.max_line_length(max_length);
98        self
99    }
100
101    /// Set maximum header field length.
102    #[inline]
103    pub const fn max_header_field_length(mut self, max_length: Option<usize>) -> Self {
104        self.decoder_options = self.decoder_options.max_header_field_length(max_length);
105        self
106    }
107
108    /// Set maximum number of lines for the response header.
109    #[inline]
110    pub const fn max_header_fields(mut self, max_fields: Option<usize>) -> Self {
111        self.decoder_options = self.decoder_options.max_header_fields(max_fields);
112        self
113    }
114
115    /// Build the client.
116    #[inline]
117    pub const fn build(self, connector: Connector) -> Client {
118        Client {
119            connector,
120            connection_timeout: self.connection_timeout,
121            read_timeout: self.read_timeout,
122            write_timeout: self.write_timeout,
123            request_timeout: self.request_timeout,
124            decoder: ResponseDecoder::new(self.decoder_options),
125        }
126    }
127}
128
129/// HTTP client.
130#[derive(Clone)]
131pub struct Client {
132    connector: Connector,
133    connection_timeout: Option<Duration>,
134    read_timeout: Option<Duration>,
135    write_timeout: Option<Duration>,
136    request_timeout: Option<Duration>,
137    decoder: ResponseDecoder,
138}
139
140impl Client {
141    /// Get a client builder.
142    #[inline]
143    pub const fn builder() -> ClientBuilder {
144        ClientBuilder::new()
145    }
146
147    /// Send a given request.
148    pub async fn request(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
149        let version = request.version();
150
151        let host = request.url().host().to_string();
152
153        let (mut builder, body) = request.into_builder();
154
155        builder = builder
156            .set_header_field(("Host", host))
157            .remove_header_field("Content-Length")
158            .remove_header_field("Transfer-Encoding");
159
160        let request = if let Some(size) = body.size() {
161            builder
162                .add_header_field(("Content-Length", size))
163                .body(body)
164        } else if version == Version::Version11 {
165            builder
166                .add_header_field(("Transfer-Encoding", "chunked"))
167                .body(Body::from_stream(ChunkedStream::new(body)))
168        } else {
169            return Err(Error::from_static_msg(
170                "body size must be known for HTTP/1.0 requests",
171            ));
172        };
173
174        // TODO: handle basic & digest auth
175        // TODO: handle redirects
176
177        self.send(request).await
178    }
179
180    /// Send a given request.
181    async fn send(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
182        let send = self.send_inner(request);
183
184        if let Some(timeout) = self.request_timeout {
185            tokio::time::timeout(timeout, send)
186                .await
187                .map_err(|_| Error::from_static_msg("request timeout"))?
188        } else {
189            send.await
190        }
191    }
192
193    /// Send a given request.
194    async fn send_inner(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
195        let (url, header, body) = request.deconstruct();
196
197        let (reader, writer) = self.connect(&url).await?.split();
198
199        let mut writer = HttpRequestWriter::new(writer);
200        let mut reader = HttpResponseReader::new(reader, self.decoder);
201
202        writer.write_header(&header).await?;
203
204        if header.get_expect_continue() {
205            let (mut response, r) = reader.read_response().await?;
206
207            let status = response.status_code();
208
209            if status == 100 {
210                reader = r
211                    .await
212                    .ok_or_else(|| Error::from_static_msg("connection lost"))?;
213            } else {
214                // note: it'd be really unusual to receive 101 Switching
215                // Protocols when 100 Continue is expected but let's assume it
216                // can happen
217                if status == 101 {
218                    let upgraded = r
219                        .await
220                        .ok_or_else(|| Error::from_static_msg("connection lost"))?
221                        .into_inner()
222                        .join(writer.into_inner())
223                        .upgrade();
224
225                    response = response.with_upgraded_connection(upgraded);
226                }
227
228                return Ok(response);
229            }
230        }
231
232        writer.write_body(body).await?;
233
234        let (mut response, r) = reader.read_response().await?;
235
236        if response.status_code() == 101 {
237            let upgraded = r
238                .await
239                .ok_or_else(|| Error::from_static_msg("connection lost"))?
240                .into_inner()
241                .join(writer.into_inner())
242                .upgrade();
243
244            response = response.with_upgraded_connection(upgraded);
245        }
246
247        Ok(response)
248    }
249
250    /// Connect to a given server.
251    async fn connect(&self, url: &Url) -> Result<HttpConnection<Connection>, Error> {
252        let connect = self.connector.connect(url);
253
254        let connection = if let Some(timeout) = self.connection_timeout {
255            tokio::time::timeout(timeout, connect)
256                .await
257                .map_err(|_| Error::from_static_msg("connection timeout"))??
258        } else {
259            connect.await?
260        };
261
262        let res = HttpConnection::builder()
263            .read_timeout(self.read_timeout)
264            .write_timeout(self.write_timeout)
265            .build(connection);
266
267        Ok(res)
268    }
269}
270
271/// Request header extensions.
272trait RequestHeaderExt {
273    /// Check if there is the 100-continue expectation set.
274    fn get_expect_continue(&self) -> bool;
275}
276
277impl RequestHeaderExt for RequestHeader {
278    fn get_expect_continue(&self) -> bool {
279        if let Some(expect) = self.get_header_field_value("expect") {
280            expect
281                .split(|&b| b == b',')
282                .map(|exp| exp.trim_ascii())
283                .filter(|exp| !exp.is_empty())
284                .any(|exp| exp.eq_ignore_ascii_case(b"100-continue"))
285        } else {
286            false
287        }
288    }
289}
290
291/// Helper struct for sending HTTP requests.
292struct HttpRequestWriter {
293    buffer: BytesMut,
294    header_encoder: RequestHeaderEncoder,
295    inner: ConnectionWriter<Connection>,
296}
297
298impl HttpRequestWriter {
299    /// Create a new writer.
300    fn new(writer: ConnectionWriter<Connection>) -> Self {
301        Self {
302            buffer: BytesMut::new(),
303            header_encoder: RequestHeaderEncoder::new(),
304            inner: writer,
305        }
306    }
307
308    /// Write a given header into the underlying connection.
309    async fn write_header(&mut self, header: &RequestHeader) -> io::Result<()> {
310        self.header_encoder.encode(header, &mut self.buffer);
311
312        self.inner.write_all(&self.buffer.split()).await?;
313        self.inner.flush().await?;
314
315        Ok(())
316    }
317
318    /// Writer a given body into the underlying connection.
319    async fn write_body(&mut self, mut body: Body) -> io::Result<()> {
320        while let Some(chunk) = body.next().await.transpose()? {
321            self.inner.write_all(&chunk).await?;
322        }
323
324        self.inner.flush().await
325    }
326
327    /// Take the underlying connection.
328    fn into_inner(self) -> ConnectionWriter<Connection> {
329        self.inner
330    }
331}
332
333/// Helper struct for reading HTTP responses.
334struct HttpResponseReader {
335    reader: ConnectionReader<Connection>,
336    decoder: ResponseDecoder,
337}
338
339impl HttpResponseReader {
340    /// Create a new response reader.
341    fn new(reader: ConnectionReader<Connection>, decoder: ResponseDecoder) -> Self {
342        Self { reader, decoder }
343    }
344
345    /// Read a response from the underlying connection.
346    async fn read_response(self) -> Result<(IncomingResponse, FutureHttpResponseReader), Error> {
347        let (response, reader) = self.decoder.decode(self.reader).await?;
348
349        let reader = FutureHttpResponseReader {
350            inner: reader,
351            decoder: self.decoder,
352        };
353
354        Ok((response, reader))
355    }
356
357    /// Take the underlying connection.
358    fn into_inner(self) -> ConnectionReader<Connection> {
359        self.reader
360    }
361}
362
363/// Future HTTP response reader.
364struct FutureHttpResponseReader {
365    inner: ConnectionReaderJoinHandle<Connection>,
366    decoder: ResponseDecoder,
367}
368
369impl Future for FutureHttpResponseReader {
370    type Output = Option<HttpResponseReader>;
371
372    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
373        let res = ready!(self.inner.poll_unpin(cx))
374            .ok()
375            .flatten()
376            .map(|reader| HttpResponseReader::new(reader, self.decoder));
377
378        Poll::Ready(res)
379    }
380}