1use std::{
4 borrow::Cow,
5 collections::VecDeque,
6 fmt::Debug,
7 future::Future,
8 io::{self, ErrorKind},
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12 time::{Duration, Instant},
13};
14
15use http::{HeaderValue, Method as HttpMethod, Uri, Version as HttpVersion, header::InvalidHeaderValue};
16use hyper::{
17 Request, Response,
18 body::{self, Body},
19 client::conn::{TrySendError, http1, http2},
20 http::uri::Scheme,
21 rt::{Sleep, Timer},
22};
23use log::{debug, error, trace};
24use lru_time_cache::LruCache;
25use pin_project::pin_project;
26use shadowsocks::relay::Address;
27use tokio::sync::Mutex;
28
29use crate::local::{context::ServiceContext, loadbalancing::PingBalancer, net::AutoProxyClientStream};
30
31use super::{
32 http_stream::ProxyHttpStream,
33 tokio_rt::{TokioExecutor, TokioIo},
34 utils::{check_keep_alive, connect_host, host_addr},
35};
36
/// Maximum age of an idle cached connection.
///
/// Used both as the expiry duration of the per-host `LruCache` and as the age limit checked
/// when a connection is taken back out of a host's queue.
const CONNECTION_EXPIRE_DURATION: Duration = Duration::from_secs(20);
38
/// Errors returned by `HttpClient::send_request`.
#[derive(thiserror::Error, Debug)]
pub enum HttpClientError {
    /// An error reported by hyper while sending the request or reading the response head.
    #[error("{0}")]
    Hyper(#[from] hyper::Error),
    /// An I/O error, e.g. failing to connect to the target host.
    #[error("{0}")]
    Io(#[from] io::Error),
    /// An error from the `http` crate, e.g. while rebuilding the request URI.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// The `Host` header value derived from the request URI was not a valid header value.
    #[error("{0}")]
    InvalidHeaderValue(#[from] InvalidHeaderValue),
}
55
/// Internal error type for sending a request over one `HttpConnection`.
// `TrySend` embeds the unsent `Request<B>` (inside `TrySendError`) so the caller can take it back
// with `take_message()` and retry on another connection. That makes it presumably much larger
// than `Http`, which is why clippy's `large_enum_variant` lint is silenced.
#[allow(clippy::large_enum_variant)]
#[derive(thiserror::Error, Debug)]
enum SendRequestError<B> {
    /// Rebuilding the request (its URI) failed before it was handed to hyper.
    #[error("{0}")]
    Http(#[from] http::Error),

    /// hyper failed to send the request; the request may be recoverable from the error.
    #[error("{0}")]
    TrySend(#[from] TrySendError<Request<B>>),
}
65
66#[derive(Clone, Debug)]
67pub struct TokioTimer;
68
69impl Timer for TokioTimer {
70 fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
71 Box::pin(TokioSleep {
72 inner: tokio::time::sleep(duration),
73 })
74 }
75
76 fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
77 Box::pin(TokioSleep {
78 inner: tokio::time::sleep_until(deadline.into()),
79 })
80 }
81
82 fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
83 if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
84 sleep.reset(new_deadline)
85 }
86 }
87}
88
89#[pin_project]
90pub(crate) struct TokioSleep {
91 #[pin]
92 pub(crate) inner: tokio::time::Sleep,
93}
94
95impl Future for TokioSleep {
96 type Output = ();
97
98 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
99 self.project().inner.poll(cx)
100 }
101}
102
103impl Sleep for TokioSleep {}
104
105impl TokioSleep {
106 pub fn reset(self: Pin<&mut Self>, deadline: Instant) {
107 self.project().inner.as_mut().reset(deadline.into());
108 }
109}
110
111pub struct HttpClient<B> {
113 #[allow(clippy::type_complexity)]
114 cache_conn: Arc<Mutex<LruCache<Address, VecDeque<(HttpConnection<B>, Instant)>>>>,
115}
116
117impl<B> Clone for HttpClient<B> {
118 fn clone(&self) -> Self {
119 Self {
120 cache_conn: self.cache_conn.clone(),
121 }
122 }
123}
124
125impl<B> Default for HttpClient<B>
126where
127 B: Body + Send + Unpin + Debug + 'static,
128 B::Data: Send,
129 B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
130{
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl<B> HttpClient<B>
137where
138 B: Body + Send + Unpin + Debug + 'static,
139 B::Data: Send,
140 B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
141{
142 pub fn new() -> Self {
144 Self {
145 cache_conn: Arc::new(Mutex::new(LruCache::with_expiry_duration(CONNECTION_EXPIRE_DURATION))),
146 }
147 }
148
149 #[inline]
151 pub async fn send_request(
152 &self,
153 context: Arc<ServiceContext>,
154 req: Request<B>,
155 balancer: Option<&PingBalancer>,
156 ) -> Result<Response<body::Incoming>, HttpClientError> {
157 let host = match host_addr(req.uri()) {
158 Some(h) => h,
159 None => panic!("URI missing host: {}", req.uri()),
160 };
161
162 let (mut req_parts, req_body) = req.into_parts();
164 if let Some(authority) = req_parts.uri.authority() {
165 let headers = &mut req_parts.headers;
166 if !headers.contains_key("Host") {
167 let uri = &req_parts.uri;
168 let host_value = if (uri.scheme_str() == Some("http")
169 && matches!(authority.port_u16(), None | Some(80)))
170 || (uri.scheme_str() == Some("https") && matches!(authority.port_u16(), None | Some(443)))
171 {
172 HeaderValue::from_str(authority.host())?
173 } else {
174 HeaderValue::from_str(authority.as_str())?
175 };
176
177 headers.insert("Host", host_value);
178 }
179 }
180 let mut req = Request::from_parts(req_parts, req_body);
181
182 if let Some(c) = self.get_cached_connection(&host).await {
184 trace!("HTTP client for host: {} taken from cache", host);
185 match self.send_request_conn(host.clone(), c, req).await {
186 Ok(response) => return Ok(response),
187 Err(SendRequestError::TrySend(mut err)) => {
188 if let Some(inner_req) = err.take_message() {
189 req = inner_req;
190
191 debug!(
193 "failed to send request via cached connection to host: {}, error: {}. retry with a new connection",
194 host,
195 err.error()
196 );
197 } else {
198 error!(
199 "failed to send request via cached connection to host: {}, error: {}. no request to retry",
200 host,
201 err.error()
202 );
203 return Err(err.into_error().into());
204 }
205 }
206 Err(SendRequestError::Http(err)) => {
207 error!(
208 "failed to send request via cached connection to host: {}, error: {}",
209 host, err
210 );
211 return Err(err.into());
212 }
213 }
214 }
215
216 let scheme = match req.uri().scheme() {
218 Some(s) => s,
219 None => &Scheme::HTTP,
220 };
221
222 let domain = match host {
223 Address::DomainNameAddress(ref domain, _) => Cow::Borrowed(domain.as_str()),
224 Address::SocketAddress(ref saddr) => Cow::Owned(saddr.ip().to_string()),
225 };
226
227 let c = match HttpConnection::connect(context.clone(), scheme, host.clone(), &domain, balancer).await {
228 Ok(c) => c,
229 Err(err) => {
230 error!("failed to connect to host: {}, error: {}", host, err);
231 return Err(err.into());
232 }
233 };
234
235 match self.send_request_conn(host, c, req).await {
236 Ok(response) => Ok(response),
237 Err(SendRequestError::TrySend(err)) => Err(err.into_error().into()),
238 Err(SendRequestError::Http(err)) => Err(err.into()),
239 }
240 }
241
242 async fn get_cached_connection(&self, host: &Address) -> Option<HttpConnection<B>> {
243 if let Some(q) = self.cache_conn.lock().await.get_mut(host) {
244 while let Some((c, inst)) = q.pop_front() {
245 let now = Instant::now();
246 if now - inst >= CONNECTION_EXPIRE_DURATION {
247 continue;
248 }
249 if c.is_closed() {
250 continue;
251 }
252 return Some(c);
253 }
254 }
255 None
256 }
257
258 async fn send_request_conn(
259 &self,
260 host: Address,
261 mut c: HttpConnection<B>,
262 req: Request<B>,
263 ) -> Result<Response<body::Incoming>, SendRequestError<B>> {
264 trace!("HTTP making request to host: {}, request: {:?}", host, req);
265 let response = c.send_request(req).await?;
266 trace!("HTTP received response from host: {}, response: {:?}", host, response);
267
268 if check_keep_alive(response.version(), response.headers(), false) {
270 trace!(
271 "HTTP connection keep-alive for host: {}, response: {:?}",
272 host, response
273 );
274 self.cache_conn
275 .lock()
276 .await
277 .entry(host)
278 .or_insert_with(VecDeque::new)
279 .push_back((c, Instant::now()));
280 }
281
282 Ok(response)
283 }
284}
285
/// The request sender of one established connection, by protocol in use.
enum HttpConnection<B> {
    /// HTTP/1.x sender: plain HTTP, or TLS where HTTP/2 was not negotiated.
    Http1(http1::SendRequest<B>),
    /// HTTP/2 sender: TLS connection that negotiated HTTP/2.
    Http2(http2::SendRequest<B>),
}
290
291impl<B> HttpConnection<B>
292where
293 B: Body + Send + Unpin + 'static,
294 B::Data: Send,
295 B::Error: Into<Box<dyn ::std::error::Error + Send + Sync>>,
296{
297 async fn connect(
298 context: Arc<ServiceContext>,
299 scheme: &Scheme,
300 host: Address,
301 domain: &str,
302 balancer: Option<&PingBalancer>,
303 ) -> io::Result<Self> {
304 if *scheme != Scheme::HTTP && *scheme != Scheme::HTTPS {
305 return Err(io::Error::new(ErrorKind::InvalidInput, "invalid scheme"));
306 }
307
308 let (stream, _) = connect_host(context, &host, balancer).await?;
309
310 if *scheme == Scheme::HTTP {
311 Self::connect_http_http1(scheme, host, stream).await
312 } else if *scheme == Scheme::HTTPS {
313 Self::connect_https(scheme, host, domain, stream).await
314 } else {
315 unreachable!()
316 }
317 }
318
319 async fn connect_http_http1(scheme: &Scheme, host: Address, stream: AutoProxyClientStream) -> io::Result<Self> {
320 trace!(
321 "HTTP making new HTTP/1.1 connection to host: {}, scheme: {}",
322 host, scheme
323 );
324
325 let stream = ProxyHttpStream::connect_http(stream);
326
327 let (send_request, connection) = match http1::Builder::new()
329 .preserve_header_case(true)
330 .title_case_headers(true)
331 .handshake(TokioIo::new(stream))
332 .await
333 {
334 Ok(s) => s,
335 Err(err) => return Err(io::Error::other(err)),
336 };
337
338 tokio::spawn(async move {
339 if let Err(err) = connection.await {
340 error!("HTTP/1.x connection to host: {} aborted with error: {}", host, err);
341 }
342 });
343
344 Ok(Self::Http1(send_request))
345 }
346
347 async fn connect_https(
348 scheme: &Scheme,
349 host: Address,
350 domain: &str,
351 stream: AutoProxyClientStream,
352 ) -> io::Result<Self> {
353 trace!("HTTP making new TLS connection to host: {}, scheme: {}", host, scheme);
354
355 let stream = ProxyHttpStream::connect_https(stream, domain).await?;
357
358 if stream.negotiated_http2() {
359 let (send_request, connection) = match http2::Builder::new(TokioExecutor)
361 .timer(TokioTimer)
362 .keep_alive_interval(Duration::from_secs(15))
363 .handshake(TokioIo::new(stream))
364 .await
365 {
366 Ok(s) => s,
367 Err(err) => return Err(io::Error::other(err)),
368 };
369
370 tokio::spawn(async move {
371 if let Err(err) = connection.await {
372 error!("HTTP/2 TLS connection to host: {} aborted with error: {}", host, err);
373 }
374 });
375
376 Ok(Self::Http2(send_request))
377 } else {
378 let (send_request, connection) = match http1::Builder::new()
380 .preserve_header_case(true)
381 .title_case_headers(true)
382 .handshake(TokioIo::new(stream))
383 .await
384 {
385 Ok(s) => s,
386 Err(err) => return Err(io::Error::other(err)),
387 };
388
389 tokio::spawn(async move {
390 if let Err(err) = connection.await {
391 error!("HTTP/1.x TLS connection to host: {} aborted with error: {}", host, err);
392 }
393 });
394
395 Ok(Self::Http1(send_request))
396 }
397 }
398
399 #[inline]
400 pub async fn send_request(&mut self, mut req: Request<B>) -> Result<Response<body::Incoming>, SendRequestError<B>> {
401 match self {
402 Self::Http1(r) => {
403 if !matches!(
404 req.version(),
405 HttpVersion::HTTP_09 | HttpVersion::HTTP_10 | HttpVersion::HTTP_11
406 ) {
407 trace!(
408 "HTTP client changed Request.version to HTTP/1.1 from {:?}",
409 req.version()
410 );
411
412 *req.version_mut() = HttpVersion::HTTP_11;
413 }
414
415 if req.method() != HttpMethod::CONNECT
417 && (req.uri().scheme().is_some() || req.uri().authority().is_some())
418 {
419 let mut builder = Uri::builder();
420 match req.uri().path_and_query() {
421 Some(path_and_query) => {
422 builder = builder.path_and_query(path_and_query.as_str());
423 }
424 _ => {
425 builder = builder.path_and_query("/");
426 }
427 }
428 *(req.uri_mut()) = builder.build()?;
429 }
430
431 r.try_send_request(req).await.map_err(Into::into)
432 }
433 Self::Http2(r) => {
434 if !matches!(req.version(), HttpVersion::HTTP_2) {
435 trace!("HTTP client changed Request.version to HTTP/2 from {:?}", req.version());
436
437 *req.version_mut() = HttpVersion::HTTP_2;
438 }
439
440 r.try_send_request(req).await.map_err(Into::into)
441 }
442 }
443 }
444
445 pub fn is_closed(&self) -> bool {
446 match self {
447 Self::Http1(r) => r.is_closed(),
448 Self::Http2(r) => r.is_closed(),
449 }
450 }
451}