shadowsocks_service/local/http/
http_client.rs

1//! HTTP Client
2
3use std::{
4    borrow::Cow,
5    collections::VecDeque,
6    fmt::Debug,
7    future::Future,
8    io::{self, ErrorKind},
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12    time::{Duration, Instant},
13};
14
15use http::{HeaderValue, Method as HttpMethod, Uri, Version as HttpVersion, header::InvalidHeaderValue};
16use hyper::{
17    Request, Response,
18    body::{self, Body},
19    client::conn::{TrySendError, http1, http2},
20    http::uri::Scheme,
21    rt::{Sleep, Timer},
22};
23use log::{debug, error, trace};
24use lru_time_cache::LruCache;
25use pin_project::pin_project;
26use shadowsocks::relay::Address;
27use tokio::sync::Mutex;
28
29use crate::local::{context::ServiceContext, loadbalancing::PingBalancer, net::AutoProxyClientStream};
30
31use super::{
32    http_stream::ProxyHttpStream,
33    tokio_rt::{TokioExecutor, TokioIo},
34    utils::{check_keep_alive, connect_host, host_addr},
35};
36
/// Maximum age of an idle keep-alive connection in the client cache (20 seconds).
///
/// Used both as the expiry of the per-host `LruCache` entries and as the per-connection age
/// limit checked when a cached connection is taken for reuse.
const CONNECTION_EXPIRE_DURATION: Duration = Duration::from_millis(20_000);
38
39/// HTTPClient API request errors
40#[derive(thiserror::Error, Debug)]
41pub enum HttpClientError {
42    /// Errors from hyper
43    #[error("{0}")]
44    Hyper(#[from] hyper::Error),
45    /// std::io::Error
46    #[error("{0}")]
47    Io(#[from] io::Error),
48    /// Errors from http
49    #[error("{0}")]
50    Http(#[from] http::Error),
51    /// Errors from http header
52    #[error("{0}")]
53    InvalidHeaderValue(#[from] InvalidHeaderValue),
54}
55
56#[allow(clippy::large_enum_variant)]
57#[derive(thiserror::Error, Debug)]
58enum SendRequestError<B> {
59    #[error("{0}")]
60    Http(#[from] http::Error),
61
62    #[error("{0}")]
63    TrySend(#[from] TrySendError<Request<B>>),
64}
65
/// `hyper::rt::Timer` implementation backed by Tokio's timer (`tokio::time`).
#[derive(Debug, Clone)]
pub struct TokioTimer;
68
69impl Timer for TokioTimer {
70    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
71        Box::pin(TokioSleep {
72            inner: tokio::time::sleep(duration),
73        })
74    }
75
76    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
77        Box::pin(TokioSleep {
78            inner: tokio::time::sleep_until(deadline.into()),
79        })
80    }
81
82    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
83        if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
84            sleep.reset(new_deadline)
85        }
86    }
87}
88
89#[pin_project]
90pub(crate) struct TokioSleep {
91    #[pin]
92    pub(crate) inner: tokio::time::Sleep,
93}
94
95impl Future for TokioSleep {
96    type Output = ();
97
98    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
99        self.project().inner.poll(cx)
100    }
101}
102
103impl Sleep for TokioSleep {}
104
105impl TokioSleep {
106    pub fn reset(self: Pin<&mut Self>, deadline: Instant) {
107        self.project().inner.as_mut().reset(deadline.into());
108    }
109}
110
111/// HTTPClient, supporting HTTP/1.1 and H2, HTTPS.
112pub struct HttpClient<B> {
113    #[allow(clippy::type_complexity)]
114    cache_conn: Arc<Mutex<LruCache<Address, VecDeque<(HttpConnection<B>, Instant)>>>>,
115}
116
117impl<B> Clone for HttpClient<B> {
118    fn clone(&self) -> Self {
119        Self {
120            cache_conn: self.cache_conn.clone(),
121        }
122    }
123}
124
125impl<B> Default for HttpClient<B>
126where
127    B: Body + Send + Unpin + Debug + 'static,
128    B::Data: Send,
129    B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
130{
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl<B> HttpClient<B>
137where
138    B: Body + Send + Unpin + Debug + 'static,
139    B::Data: Send,
140    B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
141{
142    /// Create a new HttpClient
143    pub fn new() -> Self {
144        Self {
145            cache_conn: Arc::new(Mutex::new(LruCache::with_expiry_duration(CONNECTION_EXPIRE_DURATION))),
146        }
147    }
148
149    /// Make HTTP requests
150    #[inline]
151    pub async fn send_request(
152        &self,
153        context: Arc<ServiceContext>,
154        req: Request<B>,
155        balancer: Option<&PingBalancer>,
156    ) -> Result<Response<body::Incoming>, HttpClientError> {
157        let host = match host_addr(req.uri()) {
158            Some(h) => h,
159            None => panic!("URI missing host: {}", req.uri()),
160        };
161
162        // Set Host header if it was missing in the Request
163        let (mut req_parts, req_body) = req.into_parts();
164        if let Some(authority) = req_parts.uri.authority() {
165            let headers = &mut req_parts.headers;
166            if !headers.contains_key("Host") {
167                let uri = &req_parts.uri;
168                let host_value = if (uri.scheme_str() == Some("http")
169                    && matches!(authority.port_u16(), None | Some(80)))
170                    || (uri.scheme_str() == Some("https") && matches!(authority.port_u16(), None | Some(443)))
171                {
172                    HeaderValue::from_str(authority.host())?
173                } else {
174                    HeaderValue::from_str(authority.as_str())?
175                };
176
177                headers.insert("Host", host_value);
178            }
179        }
180        let mut req = Request::from_parts(req_parts, req_body);
181
182        // 1. Check if there is an available client
183        if let Some(c) = self.get_cached_connection(&host).await {
184            trace!("HTTP client for host: {} taken from cache", host);
185            match self.send_request_conn(host.clone(), c, req).await {
186                Ok(response) => return Ok(response),
187                Err(SendRequestError::TrySend(mut err)) => {
188                    if let Some(inner_req) = err.take_message() {
189                        req = inner_req;
190
191                        // If TrySendError, the connection is probably broken, we should make a new connection
192                        debug!(
193                            "failed to send request via cached connection to host: {}, error: {}. retry with a new connection",
194                            host,
195                            err.error()
196                        );
197                    } else {
198                        error!(
199                            "failed to send request via cached connection to host: {}, error: {}. no request to retry",
200                            host,
201                            err.error()
202                        );
203                        return Err(err.into_error().into());
204                    }
205                }
206                Err(SendRequestError::Http(err)) => {
207                    error!(
208                        "failed to send request via cached connection to host: {}, error: {}",
209                        host, err
210                    );
211                    return Err(err.into());
212                }
213            }
214        }
215
216        // 2. If no. Make a new connection
217        let scheme = match req.uri().scheme() {
218            Some(s) => s,
219            None => &Scheme::HTTP,
220        };
221
222        let domain = match host {
223            Address::DomainNameAddress(ref domain, _) => Cow::Borrowed(domain.as_str()),
224            Address::SocketAddress(ref saddr) => Cow::Owned(saddr.ip().to_string()),
225        };
226
227        let c = match HttpConnection::connect(context.clone(), scheme, host.clone(), &domain, balancer).await {
228            Ok(c) => c,
229            Err(err) => {
230                error!("failed to connect to host: {}, error: {}", host, err);
231                return Err(err.into());
232            }
233        };
234
235        match self.send_request_conn(host, c, req).await {
236            Ok(response) => Ok(response),
237            Err(SendRequestError::TrySend(err)) => Err(err.into_error().into()),
238            Err(SendRequestError::Http(err)) => Err(err.into()),
239        }
240    }
241
242    async fn get_cached_connection(&self, host: &Address) -> Option<HttpConnection<B>> {
243        if let Some(q) = self.cache_conn.lock().await.get_mut(host) {
244            while let Some((c, inst)) = q.pop_front() {
245                let now = Instant::now();
246                if now - inst >= CONNECTION_EXPIRE_DURATION {
247                    continue;
248                }
249                if c.is_closed() {
250                    continue;
251                }
252                return Some(c);
253            }
254        }
255        None
256    }
257
258    async fn send_request_conn(
259        &self,
260        host: Address,
261        mut c: HttpConnection<B>,
262        req: Request<B>,
263    ) -> Result<Response<body::Incoming>, SendRequestError<B>> {
264        trace!("HTTP making request to host: {}, request: {:?}", host, req);
265        let response = c.send_request(req).await?;
266        trace!("HTTP received response from host: {}, response: {:?}", host, response);
267
268        // Check keep-alive
269        if check_keep_alive(response.version(), response.headers(), false) {
270            trace!(
271                "HTTP connection keep-alive for host: {}, response: {:?}",
272                host, response
273            );
274            self.cache_conn
275                .lock()
276                .await
277                .entry(host)
278                .or_insert_with(VecDeque::new)
279                .push_back((c, Instant::now()));
280        }
281
282        Ok(response)
283    }
284}
285
286enum HttpConnection<B> {
287    Http1(http1::SendRequest<B>),
288    Http2(http2::SendRequest<B>),
289}
290
291impl<B> HttpConnection<B>
292where
293    B: Body + Send + Unpin + 'static,
294    B::Data: Send,
295    B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
296{
297    async fn connect(
298        context: Arc<ServiceContext>,
299        scheme: &Scheme,
300        host: Address,
301        domain: &str,
302        balancer: Option<&PingBalancer>,
303    ) -> io::Result<Self> {
304        if *scheme != Scheme::HTTP && *scheme != Scheme::HTTPS {
305            return Err(io::Error::new(ErrorKind::InvalidInput, "invalid scheme"));
306        }
307
308        let (stream, _) = connect_host(context, &host, balancer).await?;
309
310        if *scheme == Scheme::HTTP {
311            Self::connect_http_http1(scheme, host, stream).await
312        } else if *scheme == Scheme::HTTPS {
313            Self::connect_https(scheme, host, domain, stream).await
314        } else {
315            unreachable!()
316        }
317    }
318
319    async fn connect_http_http1(scheme: &Scheme, host: Address, stream: AutoProxyClientStream) -> io::Result<Self> {
320        trace!(
321            "HTTP making new HTTP/1.1 connection to host: {}, scheme: {}",
322            host, scheme
323        );
324
325        let stream = ProxyHttpStream::connect_http(stream);
326
327        // HTTP/1.x
328        let (send_request, connection) = match http1::Builder::new()
329            .preserve_header_case(true)
330            .title_case_headers(true)
331            .handshake(TokioIo::new(stream))
332            .await
333        {
334            Ok(s) => s,
335            Err(err) => return Err(io::Error::other(err)),
336        };
337
338        tokio::spawn(async move {
339            if let Err(err) = connection.await {
340                error!("HTTP/1.x connection to host: {} aborted with error: {}", host, err);
341            }
342        });
343
344        Ok(Self::Http1(send_request))
345    }
346
347    async fn connect_https(
348        scheme: &Scheme,
349        host: Address,
350        domain: &str,
351        stream: AutoProxyClientStream,
352    ) -> io::Result<Self> {
353        trace!("HTTP making new TLS connection to host: {}, scheme: {}", host, scheme);
354
355        // TLS handshake, check alpn for h2 support.
356        let stream = ProxyHttpStream::connect_https(stream, domain).await?;
357
358        if stream.negotiated_http2() {
359            // H2 connection
360            let (send_request, connection) = match http2::Builder::new(TokioExecutor)
361                .timer(TokioTimer)
362                .keep_alive_interval(Duration::from_secs(15))
363                .handshake(TokioIo::new(stream))
364                .await
365            {
366                Ok(s) => s,
367                Err(err) => return Err(io::Error::other(err)),
368            };
369
370            tokio::spawn(async move {
371                if let Err(err) = connection.await {
372                    error!("HTTP/2 TLS connection to host: {} aborted with error: {}", host, err);
373                }
374            });
375
376            Ok(Self::Http2(send_request))
377        } else {
378            // HTTP/1.x TLS
379            let (send_request, connection) = match http1::Builder::new()
380                .preserve_header_case(true)
381                .title_case_headers(true)
382                .handshake(TokioIo::new(stream))
383                .await
384            {
385                Ok(s) => s,
386                Err(err) => return Err(io::Error::other(err)),
387            };
388
389            tokio::spawn(async move {
390                if let Err(err) = connection.await {
391                    error!("HTTP/1.x TLS connection to host: {} aborted with error: {}", host, err);
392                }
393            });
394
395            Ok(Self::Http1(send_request))
396        }
397    }
398
399    #[inline]
400    pub async fn send_request(&mut self, mut req: Request<B>) -> Result<Response<body::Incoming>, SendRequestError<B>> {
401        match self {
402            Self::Http1(r) => {
403                if !matches!(
404                    req.version(),
405                    HttpVersion::HTTP_09 | HttpVersion::HTTP_10 | HttpVersion::HTTP_11
406                ) {
407                    trace!(
408                        "HTTP client changed Request.version to HTTP/1.1 from {:?}",
409                        req.version()
410                    );
411
412                    *req.version_mut() = HttpVersion::HTTP_11;
413                }
414
415                // Remove Scheme, Host part from URI
416                if req.method() != HttpMethod::CONNECT
417                    && (req.uri().scheme().is_some() || req.uri().authority().is_some())
418                {
419                    let mut builder = Uri::builder();
420                    match req.uri().path_and_query() {
421                        Some(path_and_query) => {
422                            builder = builder.path_and_query(path_and_query.as_str());
423                        }
424                        _ => {
425                            builder = builder.path_and_query("/");
426                        }
427                    }
428                    *(req.uri_mut()) = builder.build()?;
429                }
430
431                r.try_send_request(req).await.map_err(Into::into)
432            }
433            Self::Http2(r) => {
434                if !matches!(req.version(), HttpVersion::HTTP_2) {
435                    trace!("HTTP client changed Request.version to HTTP/2 from {:?}", req.version());
436
437                    *req.version_mut() = HttpVersion::HTTP_2;
438                }
439
440                r.try_send_request(req).await.map_err(Into::into)
441            }
442        }
443    }
444
445    pub fn is_closed(&self) -> bool {
446        match self {
447            Self::Http1(r) => r.is_closed(),
448            Self::Http2(r) => r.is_closed(),
449        }
450    }
451}