Skip to main content

warp_reverse_proxy/
lib.rs

//! Fully composable [warp](https://github.com/seanmonstar/warp) filter that can be used as a reverse proxy. It forwards the request to the
//! desired address and replies with the response received from that address.
3//!
4//!
5//! ```no_run
6//! use warp::{Filter, Rejection, Reply, reply::Response};
7//! use warp_reverse_proxy::reverse_proxy_filter;
8//!
9//! async fn log_response(response: Response) -> Result<impl Reply, Rejection> {
10//!     println!("{:?}", response);
11//!     Ok(response)
12//! }
13//!
14//! #[tokio::main]
15//! async fn main() {
16//!     let hello = warp::path!("hello" / String).map(|name| format!("Hello, {}!", name));
17//!
//!     // spawn base server
19//!     tokio::spawn(warp::serve(hello).run(([0, 0, 0, 0], 8080)));
20//!
21//!     // Forward request to localhost in other port
22//!     let app = warp::path!("hello" / ..).and(
23//!         reverse_proxy_filter("".to_string(), "http://127.0.0.1:8080/".to_string())
24//!             .and_then(log_response),
25//!     );
26//!
27//!     // spawn proxy server
28//!     warp::serve(app).run(([0, 0, 0, 0], 3030)).await;
29//! }
30//! ```
31pub mod errors;
32
33use hyper::body::Bytes;
34use once_cell::sync::{Lazy, OnceCell};
35use reqwest::redirect::Policy;
36use unicase::Ascii;
37use warp::filters::path::FullPath;
38use warp::http::{HeaderMap, HeaderValue, Method as RequestMethod};
39use warp::{reply, Reply};
40use warp::{Filter, Rejection};
41
/// Reverse proxy internal client
///
/// Every proxied request is built and executed through this client. If it has not been
/// set explicitly, it is initialized on first use with a default `reqwest::Client`
/// whose redirect policy is `Policy::none()`.
///
/// It can be overridden if needed by calling `OnceCell::set`. A `OnceCell` can only be
/// written once, so the override has to happen before the first request is proxied;
/// afterwards `set` returns an error.
/// # Examples
/// ```
/// use warp_reverse_proxy::CLIENT;
/// use reqwest::Client;
///
/// let client = Client::builder().build().expect("client goes boom...");
/// CLIENT.set(client).expect("client is set");
/// ```
pub static CLIENT: OnceCell<reqwest::Client> = OnceCell::new();
54
/// Alias of warp `FullPath`, the type produced by `warp::path::full()`.
pub type Uri = FullPath;

/// Alias of query parameters.
///
/// This is the type that holds the raw request query string. It is `None` when
/// `warp::query::raw()` rejects the request (presumably because it carries no query string).
pub type QueryParameters = Option<String>;

/// Alias of warp `Method`
pub type Method = RequestMethod;

/// Alias of warp `HeaderMap`
pub type Headers = HeaderMap;

/// Alias of the request data tuple: path, query parameters, method, headers and body.
///
/// It is the type that holds the request data extracted by the [`extract_request_data_filter`](fn.extract_request_data_filter.html) filter.
pub type Request = (Uri, QueryParameters, Method, Headers, Bytes);
73
74/// Reverse proxy filter
75///
76/// Forwards the request to the desired location. It maps one to one, meaning
77/// that a request to `https://www.bar.foo/handle/this/path` forwarding to `https://www.other.location`
78/// will result in a request to `https://www.other.location/handle/this/path`.
79///
80/// # Arguments
81///
82/// * `base_path` - A string with the initial relative path of the endpoint.
83///   For example a `foo/` applied for an endpoint `foo/bar/` will result on a proxy to `bar/` (hence `/foo` is removed)
84///
85/// * `proxy_address` - Base proxy address to forward request.
86/// # Examples
87///
88/// When making a filter with a path `/handle/this/path` combined with a filter built
89/// with `reverse_proxy_filter("handle".to_string(), "localhost:8080")`
90/// will make that request arriving to `https://www.bar.foo/handle/this/path` be forwarded to `localhost:8080/this/path`
91pub fn reverse_proxy_filter(
92    base_path: String,
93    proxy_address: String,
94) -> impl Filter<Extract = (warp::reply::Response,), Error = Rejection> + Clone {
95    let proxy_address = warp::any().map(move || proxy_address.clone());
96    let base_path = warp::any().map(move || base_path.clone());
97    let data_filter = extract_request_data_filter();
98
99    proxy_address
100        .and(base_path)
101        .and(data_filter)
102        .and_then(proxy_to_and_forward_response)
103        .boxed()
104}
105
106/// Warp filter that extracts query parameters from the request, if they exist.
107pub fn query_params_filter(
108) -> impl Filter<Extract = (QueryParameters,), Error = std::convert::Infallible> + Clone {
109    warp::query::raw()
110        .map(Some)
111        .or_else(|_| async { Ok::<(QueryParameters,), std::convert::Infallible>((None,)) })
112}
113
114/// Warp filter that extracts the relative request path, method, headers map and body of a request.
115pub fn extract_request_data_filter(
116) -> impl Filter<Extract = Request, Error = warp::Rejection> + Clone {
117    warp::path::full()
118        .and(query_params_filter())
119        .and(warp::method())
120        .and(warp::header::headers_cloned())
121        .and(warp::body::bytes())
122}
123
124/// Build a request and send to the requested address.
125///
126/// Wraps the response into a `warp::reply` compatible type (`http::Response`)
127///
128/// # Arguments
129///
130/// * `proxy_address` - A string containing the base proxy address where the request
131///   will be forwarded to.
132///
133/// * `base_path` - A string with the prepended sub-path to be stripped from the request uri path.
134///
135/// * `uri` -> The uri of the extracted request.
136///
137/// * `params` -> The URL query parameters
138///
139/// * `method` -> The request method.
140///
141/// * `headers` -> The request headers.
142///
143/// * `body` -> The request body.
144///
145/// # Examples
146/// Notice that this method usually need to be used in aggregation with
147/// the [`extract_request_data_filter`](fn.extract_request_data_filter.html) filter which already
148/// provides the `(Uri, QueryParameters, Method, Headers, Body)` needed for calling this method. But the `proxy_address`
149/// and the `base_path` arguments need to be provided too.
150/// ```rust, ignore
151/// use warp::{Filter, Reply, Rejection, reply::Response};
152/// use warp_reverse_proxy::{extract_request_data_filter, proxy_to_and_forward_response};
153///
154/// async fn log_response(response: Response) -> Result<impl Reply, Rejection> {
155///     println!("{:?}", response);
156///     Ok(response)
157/// }
158///
159/// let request_filter = extract_request_data_filter();
160/// let app = warp::path!("hello" / String)
161///     .map(|port| (format!("http://127.0.0.1:{}/", port), "".to_string()))
162///     .untuple_one()
163///     .and(request_filter)
164///     .and_then(proxy_to_and_forward_response)
165///     .and_then(log_response);
166/// ```
167pub async fn proxy_to_and_forward_response(
168    proxy_address: String,
169    base_path: String,
170    uri: FullPath,
171    params: QueryParameters,
172    method: Method,
173    headers: HeaderMap,
174    body: Bytes,
175) -> Result<warp::reply::Response, Rejection> {
176    let proxy_uri = remove_relative_path(&uri, base_path, proxy_address);
177    let request = filtered_data_to_request(proxy_uri, (uri, params, method, headers, body))
178        .map_err(warp::reject::custom)?;
179    let response = proxy_request(request).await.map_err(warp::reject::custom)?;
180    response_to_reply(response)
181        .await
182        .map_err(warp::reject::custom)
183}
184
185/// Converts a reqwest response into a streaming warp response.
186async fn response_to_reply(
187    response: reqwest::Response,
188) -> Result<warp::reply::Response, errors::Error> {
189    let status = response.status();
190    let headers = response.headers().clone();
191    let mut reply = reply::stream(response.bytes_stream()).into_response();
192    *reply.status_mut() = status;
193    for (k, v) in remove_hop_headers(headers).into_iter() {
194        reply.headers_mut().insert(
195            k.expect("This is previously filters out in remove hop headers call"),
196            v,
197        );
198    }
199    Ok(reply)
200}
201
202fn remove_relative_path(uri: &FullPath, base_path: String, proxy_address: String) -> String {
203    let mut base_path = base_path;
204    if !base_path.starts_with('/') {
205        base_path = format!("/{}", base_path);
206    }
207    let relative_path = uri
208        .as_str()
209        .trim_start_matches(&base_path)
210        .trim_start_matches('/');
211
212    let proxy_address = proxy_address.trim_end_matches('/');
213    format!("{}/{}", proxy_address, relative_path)
214}
215
216/// Checker method to filter hop headers
217///
218/// Headers are checked using unicase to avoid case misfunctions
219fn is_hop_header(header_name: &str) -> bool {
220    static HOP_HEADERS: Lazy<Vec<Ascii<&'static str>>> = Lazy::new(|| {
221        vec![
222            Ascii::new("Connection"),
223            Ascii::new("Keep-Alive"),
224            Ascii::new("Proxy-Authenticate"),
225            Ascii::new("Proxy-Authorization"),
226            Ascii::new("Te"),
227            Ascii::new("Trailers"),
228            Ascii::new("Transfer-Encoding"),
229            Ascii::new("Upgrade"),
230        ]
231    });
232
233    HOP_HEADERS.iter().any(|h| h == &header_name)
234}
235
236fn remove_hop_headers(headers: HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
237    headers
238        .into_iter()
239        .filter_map(|(k, v)| {
240            if matches!(k, Some(ref k) if !is_hop_header(k.as_str())) {
241                Some((k.unwrap(), v))
242            } else {
243                None
244            }
245        })
246        .collect()
247}
248
249fn filtered_data_to_request(
250    proxy_address: String,
251    request: Request,
252) -> Result<reqwest::Request, errors::Error> {
253    let (_uri, params, method, headers, body) = request;
254
255    let proxy_uri = if let Some(params) = params {
256        format!("{}?{}", proxy_address, params)
257    } else {
258        proxy_address
259    };
260
261    let headers = remove_hop_headers(headers);
262
263    CLIENT
264        .get_or_init(default_reqwest_client)
265        .request(method, proxy_uri)
266        .headers(headers)
267        .body(body)
268        .build()
269        .map_err(errors::Error::Request)
270}
271
272/// Build and send a request to the specified address and request data
273async fn proxy_request(request: reqwest::Request) -> Result<reqwest::Response, errors::Error> {
274    CLIENT
275        .get_or_init(default_reqwest_client)
276        .execute(request)
277        .await
278        .map_err(errors::Error::Request)
279}
280
281/// Build a default client with redirect policy set to none
282fn default_reqwest_client() -> reqwest::Client {
283    reqwest::Client::builder()
284        .redirect(Policy::none())
285        .build()
286        // we should panic here, it is enforce that the client is needed, and there is no error
287        // handling possible on function call, better to stop execution.
288        .expect("Default reqwest client couldn't build")
289}
290
#[cfg(test)]
pub mod test {
    use crate::{
        extract_request_data_filter, filtered_data_to_request, proxy_request, remove_relative_path,
        reverse_proxy_filter, Request,
    };
    use std::net::SocketAddr;
    use warp::http::StatusCode;
    use warp::Filter;

    /// Spawns a background warp server on `address` that answers with `warp::reply()`.
    ///
    /// With an empty `path` every route is answered; otherwise only the routes matched by
    /// `warp::path(path)` are.
    fn serve_test_response(path: String, address: SocketAddr) {
        if path.is_empty() {
            let catch_all = warp::any().map(warp::reply);
            tokio::spawn(warp::serve(catch_all).run(address));
            return;
        }

        let single_route = warp::path(path).map(warp::reply);
        tokio::spawn(warp::serve(single_route).run(address));
    }

    /// The extraction filter must hand back exactly the data the request was built with.
    #[tokio::test]
    async fn request_data_match() {
        let filter = extract_request_data_filter();

        let path = "/foo/bar";
        let query = "foo=bar";
        let method = "POST";
        let body = b"foo bar";
        let (header_name, header_value) = ("foo", "bar");
        let path_with_query = format!("{}?{}", path, query);

        let extracted = warp::test::request()
            .path(path_with_query.as_str())
            .method(method)
            .body(body)
            .header(header_name, header_value)
            .filter(&filter)
            .await;

        let request: Request = extracted.unwrap();

        assert_eq!(path, request.0.as_str());
        assert_eq!(Some(query.to_string()), request.1);
        assert_eq!(method, request.2.as_str());
        assert_eq!(bytes::Bytes::from(body.to_vec()), request.4);
        assert_eq!(request.3.get(header_name).unwrap(), header_value);
    }

    /// Extracted request data can be turned into a real request and sent to a live server.
    #[tokio::test]
    async fn proxy_forward_response() {
        let filter = extract_request_data_filter();

        let path_with_params = "http://127.0.0.1:3030/foo/bar?foo=bar";
        let method = "GET";
        let body = b"foo bar";
        let (header_name, header_value) = ("foo", "bar");

        let extracted = warp::test::request()
            .path(path_with_params)
            .method(method)
            .body(body)
            .header(header_name, header_value)
            .filter(&filter)
            .await;
        let request: Request = extracted.unwrap();

        let address: SocketAddr = ([127, 0, 0, 1], 4040).into();
        serve_test_response("".to_string(), address);

        // Give the spawned server a chance to start before sending the request.
        tokio::task::yield_now().await;

        // transform request data into an actual request
        let proxy_uri = remove_relative_path(
            &request.0,
            "".to_string(),
            "http://127.0.0.1:4040".to_string(),
        );
        let outgoing = filtered_data_to_request(proxy_uri, request).unwrap();

        let response = proxy_request(outgoing).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// The complete filter strips the base path and forwards the request to the server.
    #[tokio::test]
    async fn full_reverse_proxy_filter_forward_response() {
        let address_str = "http://127.0.0.1:3030";
        let proxy = reverse_proxy_filter("relative_path".to_string(), address_str.to_string());
        let filter = warp::path!("relative_path" / ..).and(proxy);

        let address: SocketAddr = ([127, 0, 0, 1], 3030).into();
        let path = "https://127.0.0.1:3030/relative_path/foo";
        let method = "GET";
        let body = b"foo bar";
        let (header_name, header_value) = ("foo", "bar");

        serve_test_response("foo".to_string(), address);
        // Give the spawned server a chance to start before sending the request.
        tokio::task::yield_now().await;

        let response = warp::test::request()
            .path(path)
            .method(method)
            .body(body)
            .header(header_name, header_value)
            .reply(&filter)
            .await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}