1pub mod errors;
32
33use hyper::body::Bytes;
34use once_cell::sync::{Lazy, OnceCell};
35use reqwest::redirect::Policy;
36use unicase::Ascii;
37use warp::filters::path::FullPath;
38use warp::http::{HeaderMap, HeaderValue, Method as RequestMethod};
39use warp::{reply, Reply};
40use warp::{Filter, Rejection};
41
/// HTTP client shared by every proxied request.
///
/// It is created lazily with `default_reqwest_client` on first use. Because the
/// cell is public, an application can install its own configured client with
/// `CLIENT.set(..)` before the first request is proxied.
pub static CLIENT: OnceCell<reqwest::Client> = OnceCell::new();

/// Full request path, as extracted by `warp::path::full()`.
pub type Uri = FullPath;

/// Raw query string without the leading `?`, or `None` when the request has none.
pub type QueryParameters = Option<String>;

/// HTTP method of the incoming request.
pub type Method = RequestMethod;

/// Headers of the incoming request.
pub type Headers = HeaderMap;

/// Everything extracted from an incoming request by `extract_request_data_filter`:
/// path, query string, method, headers and the fully-buffered body.
pub type Request = (Uri, QueryParameters, Method, Headers, Bytes);
73
74pub fn reverse_proxy_filter(
92 base_path: String,
93 proxy_address: String,
94) -> impl Filter<Extract = (warp::reply::Response,), Error = Rejection> + Clone {
95 let proxy_address = warp::any().map(move || proxy_address.clone());
96 let base_path = warp::any().map(move || base_path.clone());
97 let data_filter = extract_request_data_filter();
98
99 proxy_address
100 .and(base_path)
101 .and(data_filter)
102 .and_then(proxy_to_and_forward_response)
103 .boxed()
104}
105
106pub fn query_params_filter(
108) -> impl Filter<Extract = (QueryParameters,), Error = std::convert::Infallible> + Clone {
109 warp::query::raw()
110 .map(Some)
111 .or_else(|_| async { Ok::<(QueryParameters,), std::convert::Infallible>((None,)) })
112}
113
114pub fn extract_request_data_filter(
116) -> impl Filter<Extract = Request, Error = warp::Rejection> + Clone {
117 warp::path::full()
118 .and(query_params_filter())
119 .and(warp::method())
120 .and(warp::header::headers_cloned())
121 .and(warp::body::bytes())
122}
123
124pub async fn proxy_to_and_forward_response(
168 proxy_address: String,
169 base_path: String,
170 uri: FullPath,
171 params: QueryParameters,
172 method: Method,
173 headers: HeaderMap,
174 body: Bytes,
175) -> Result<warp::reply::Response, Rejection> {
176 let proxy_uri = remove_relative_path(&uri, base_path, proxy_address);
177 let request = filtered_data_to_request(proxy_uri, (uri, params, method, headers, body))
178 .map_err(warp::reject::custom)?;
179 let response = proxy_request(request).await.map_err(warp::reject::custom)?;
180 response_to_reply(response)
181 .await
182 .map_err(warp::reject::custom)
183}
184
185async fn response_to_reply(
187 response: reqwest::Response,
188) -> Result<warp::reply::Response, errors::Error> {
189 let status = response.status();
190 let headers = response.headers().clone();
191 let mut reply = reply::stream(response.bytes_stream()).into_response();
192 *reply.status_mut() = status;
193 for (k, v) in remove_hop_headers(headers).into_iter() {
194 reply.headers_mut().insert(
195 k.expect("This is previously filters out in remove hop headers call"),
196 v,
197 );
198 }
199 Ok(reply)
200}
201
202fn remove_relative_path(uri: &FullPath, base_path: String, proxy_address: String) -> String {
203 let mut base_path = base_path;
204 if !base_path.starts_with('/') {
205 base_path = format!("/{}", base_path);
206 }
207 let relative_path = uri
208 .as_str()
209 .trim_start_matches(&base_path)
210 .trim_start_matches('/');
211
212 let proxy_address = proxy_address.trim_end_matches('/');
213 format!("{}/{}", proxy_address, relative_path)
214}
215
216fn is_hop_header(header_name: &str) -> bool {
220 static HOP_HEADERS: Lazy<Vec<Ascii<&'static str>>> = Lazy::new(|| {
221 vec![
222 Ascii::new("Connection"),
223 Ascii::new("Keep-Alive"),
224 Ascii::new("Proxy-Authenticate"),
225 Ascii::new("Proxy-Authorization"),
226 Ascii::new("Te"),
227 Ascii::new("Trailers"),
228 Ascii::new("Transfer-Encoding"),
229 Ascii::new("Upgrade"),
230 ]
231 });
232
233 HOP_HEADERS.iter().any(|h| h == &header_name)
234}
235
236fn remove_hop_headers(headers: HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
237 headers
238 .into_iter()
239 .filter_map(|(k, v)| {
240 if matches!(k, Some(ref k) if !is_hop_header(k.as_str())) {
241 Some((k.unwrap(), v))
242 } else {
243 None
244 }
245 })
246 .collect()
247}
248
249fn filtered_data_to_request(
250 proxy_address: String,
251 request: Request,
252) -> Result<reqwest::Request, errors::Error> {
253 let (_uri, params, method, headers, body) = request;
254
255 let proxy_uri = if let Some(params) = params {
256 format!("{}?{}", proxy_address, params)
257 } else {
258 proxy_address
259 };
260
261 let headers = remove_hop_headers(headers);
262
263 CLIENT
264 .get_or_init(default_reqwest_client)
265 .request(method, proxy_uri)
266 .headers(headers)
267 .body(body)
268 .build()
269 .map_err(errors::Error::Request)
270}
271
272async fn proxy_request(request: reqwest::Request) -> Result<reqwest::Response, errors::Error> {
274 CLIENT
275 .get_or_init(default_reqwest_client)
276 .execute(request)
277 .await
278 .map_err(errors::Error::Request)
279}
280
281fn default_reqwest_client() -> reqwest::Client {
283 reqwest::Client::builder()
284 .redirect(Policy::none())
285 .build()
286 .expect("Default reqwest client couldn't build")
289}
290
#[cfg(test)]
pub mod test {
    use crate::{
        extract_request_data_filter, filtered_data_to_request, proxy_request, remove_relative_path,
        reverse_proxy_filter, Request,
    };
    use std::net::SocketAddr;
    use warp::http::StatusCode;
    use warp::Filter;

    /// Spawns a warp server on `address` that answers with an empty `200 OK`.
    ///
    /// With an empty `path`, every request matches. Otherwise the server only
    /// answers requests whose first path segment is `path`.
    fn serve_test_response(path: String, address: SocketAddr) {
        if path.is_empty() {
            tokio::spawn(warp::serve(warp::any().map(warp::reply)).run(address));
        } else {
            tokio::spawn(warp::serve(warp::path(path).map(warp::reply)).run(address));
        }
    }

    /// `extract_request_data_filter` must return the path, raw query, method,
    /// headers and body of the incoming request unchanged.
    #[tokio::test]
    async fn request_data_match() {
        let filter = extract_request_data_filter();

        let (path, query, method, body, header) =
            ("/foo/bar", "foo=bar", "POST", b"foo bar", ("foo", "bar"));
        let path_with_query = format!("{}?{}", path, query);

        let result = warp::test::request()
            .path(path_with_query.as_str())
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let (result_path, result_query, result_method, result_headers, result_body): Request =
            result.unwrap();

        assert_eq!(path, result_path.as_str());
        assert_eq!(Some(query.to_string()), result_query);
        assert_eq!(method, result_method.as_str());
        assert_eq!(bytes::Bytes::from(body.to_vec()), result_body);
        assert_eq!(result_headers.get(header.0).unwrap(), header.1);
    }

    /// Extracts request data, retargets it at a local server on port 4040 (empty
    /// base path), sends it, and expects `200 OK` back.
    #[tokio::test]
    async fn proxy_forward_response() {
        let filter = extract_request_data_filter();
        // Only the path and query of this URI end up in the extracted data; the
        // host/port part is not used for forwarding.
        let (path_with_params, method, body, header) = (
            "http://127.0.0.1:3030/foo/bar?foo=bar",
            "GET",
            b"foo bar",
            ("foo", "bar"),
        );

        let result = warp::test::request()
            .path(path_with_params)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let request: Request = result.unwrap();

        let address = ([127, 0, 0, 1], 4040);
        serve_test_response("".to_string(), address.into());

        // NOTE(review): appears intended to let the spawned server task start
        // listening before the request is sent; this is a best-effort wait.
        tokio::task::yield_now().await;
        let request = filtered_data_to_request(
            remove_relative_path(
                &request.0,
                "".to_string(),
                "http://127.0.0.1:4040".to_string(),
            ),
            request,
        )
        .unwrap();
        let response = proxy_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// End-to-end check of the full filter.
    ///
    /// The proxy is mounted under `relative_path` and forwards
    /// `/relative_path/foo` to an upstream server on port 3030 that answers `foo`.
    #[tokio::test]
    async fn full_reverse_proxy_filter_forward_response() {
        let address_str = "http://127.0.0.1:3030";
        let filter = warp::path!("relative_path" / ..).and(reverse_proxy_filter(
            "relative_path".to_string(),
            address_str.to_string(),
        ));
        let address = ([127, 0, 0, 1], 3030);
        let (path, method, body, header) = (
            "https://127.0.0.1:3030/relative_path/foo",
            "GET",
            b"foo bar",
            ("foo", "bar"),
        );

        serve_test_response("foo".to_string(), address.into());
        // NOTE(review): best-effort wait for the spawned upstream server to start.
        tokio::task::yield_now().await;

        let response = warp::test::request()
            .path(path)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .reply(&filter)
            .await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}