Skip to main content

tower_http_client/
rewrite_uri.rs

1//! Middleware for rewriting request URIs.
2//!
3//! Use [`RewriteUriLayer`] when the client builds requests with relative URIs
4//! and the target host must be resolved later — for example, to implement
5//! load balancing or to switch between staging and production environments.
6//!
7//! # Example
8//!
9//! Routing requests to a backend chosen at runtime (e.g. load balancing):
10//!
11//! ```rust
12//! use http::Uri;
13//! use tower::ServiceBuilder;
14//! use tower_http_client::rewrite_uri::RewriteUriLayer;
15//!
16//! // Imagine `pick_node()` returns the address of the least-loaded backend.
17//! fn pick_node() -> &'static str { "http://node-3.internal" }
18//!
19//! let layer = RewriteUriLayer::new(|uri: &Uri| {
20//!     let base = pick_node();
21//!     let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
22//!     format!("{base}{path}").parse::<Uri>().map_err(http::Error::from)
23//! });
24//! ```
25//!
26//! Using a struct implementing [`RewriteUri`] to switch between environments:
27//!
28//! ```rust
29#![doc = include_str!("../examples/rewrite_uri.rs")]
30//! ```
31
32use std::task::{Context, Poll};
33
34use futures_util::future::{Either, Ready, ready};
35use tower_layer::Layer;
36use tower_service::Service;
37
38/// Trait for rewriting URIs on incoming requests.
39///
40/// Implement this trait to define custom URI rewriting logic.  A blanket
41/// implementation is provided for closures of the form
42/// `FnMut(&http::Uri) -> Result<http::Uri, E>`.
43pub trait RewriteUri {
44    /// The error type returned when rewriting fails.
45    type Error;
46
47    /// Rewrite the given URI, returning a new URI or an error.
48    fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error>;
49}
50
51impl<F, E> RewriteUri for F
52where
53    F: FnMut(&http::Uri) -> Result<http::Uri, E>,
54{
55    type Error = E;
56
57    fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error> {
58        self(uri)
59    }
60}
61
62/// Layer that applies URI rewriting to every request via a [`RewriteUri`] policy.
63///
64/// Wraps an inner service and rewrites the URI of each incoming request before
65/// forwarding it.
66#[derive(Debug, Clone)]
67pub struct RewriteUriLayer<R> {
68    rewrite: R,
69}
70
71impl<R> RewriteUriLayer<R> {
72    /// Create a new [`RewriteUriLayer`] with the given rewrite policy.
73    pub fn new(rewrite: R) -> Self {
74        Self { rewrite }
75    }
76}
77
78impl<S, R: Clone> Layer<S> for RewriteUriLayer<R> {
79    type Service = RewriteUriService<S, R>;
80
81    fn layer(&self, inner: S) -> Self::Service {
82        RewriteUriService::new(inner, self.rewrite.clone())
83    }
84}
85
86/// Middleware that rewrites the URI of each request using a [`RewriteUri`] policy.
87#[derive(Debug, Clone)]
88pub struct RewriteUriService<S, R> {
89    inner: S,
90    rewrite: R,
91}
92
93impl<S, R> RewriteUriService<S, R> {
94    /// Create a new [`RewriteUriService`].
95    pub fn new(inner: S, rewrite: R) -> Self {
96        Self { inner, rewrite }
97    }
98}
99
100impl<S, R, ReqBody> Service<http::Request<ReqBody>> for RewriteUriService<S, R>
101where
102    S: Service<http::Request<ReqBody>>,
103    R: RewriteUri,
104    R::Error: Into<S::Error>,
105{
106    type Response = S::Response;
107    type Error = S::Error;
108    type Future = Either<Ready<Result<S::Response, S::Error>>, S::Future>;
109
110    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111        self.inner.poll_ready(cx)
112    }
113
114    fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
115        match self.rewrite.rewrite_uri(req.uri()) {
116            Ok(new_uri) => {
117                *req.uri_mut() = new_uri;
118                Either::Right(self.inner.call(req))
119            }
120            Err(e) => Either::Left(ready(Err(e.into()))),
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use http::{Request, Response, Uri};
    use tower::{ServiceBuilder, ServiceExt as _, service_fn};
    use tower_layer::Layer as _;

    use super::{RewriteUri, RewriteUriLayer, RewriteUriService};

    /// A minimal service that returns the request URI as the response body.
    fn capture_uri_service()
    -> impl tower_service::Service<Request<()>, Response = Response<String>, Error = Infallible>
    {
        service_fn(|req: Request<()>| async move {
            Ok::<_, Infallible>(Response::new(req.uri().to_string()))
        })
    }

    /// Builds a bodiless request for `uri`.
    fn request(uri: &str) -> Request<()> {
        Request::builder().uri(uri).body(()).unwrap()
    }

    // All tests drive services through `ServiceExt::oneshot`, which awaits
    // `poll_ready` before `call`, as tower's `Service` contract requires.

    #[tokio::test]
    async fn test_rewrite_uri_with_closure() {
        let svc = RewriteUriService::new(capture_uri_service(), |uri: &Uri| {
            // The policy must observe the URI the request was built with.
            assert_eq!(uri.path(), "/original");
            Ok::<_, Infallible>(Uri::from_static("http://example.com/rewritten"))
        });

        let response = svc.oneshot(request("/original")).await.unwrap();
        assert_eq!(response.into_body(), "http://example.com/rewritten");
    }

    #[tokio::test]
    async fn test_rewrite_uri_layer() {
        let svc = RewriteUriLayer::new(|_uri: &Uri| {
            Ok::<_, Infallible>(Uri::from_static("http://example.com/via-layer"))
        })
        .layer(capture_uri_service());

        let response = svc.oneshot(request("/original")).await.unwrap();
        assert_eq!(response.into_body(), "http://example.com/via-layer");
    }

    #[tokio::test]
    async fn test_rewrite_uri_service_builder() {
        let svc = ServiceBuilder::new()
            .layer(RewriteUriLayer::new(|uri: &Uri| {
                let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
                let new_uri: Uri = format!("http://example.com{path}").parse().unwrap();
                Ok::<_, Infallible>(new_uri)
            }))
            .service(capture_uri_service());

        // The query string must survive the rewrite along with the path.
        let response = svc.oneshot(request("/hello?x=1")).await.unwrap();
        assert_eq!(response.into_body(), "http://example.com/hello?x=1");
    }

    #[tokio::test]
    async fn test_rewrite_uri_error_propagates() {
        // Count inner invocations: a failed rewrite must short-circuit.
        let calls = Arc::new(AtomicUsize::new(0));
        let inner_calls = Arc::clone(&calls);

        // Use String as a convenient non-Infallible error type for both service
        // and rewriter so that String: Into<String> is satisfied.
        let inner = service_fn(move |_: Request<()>| {
            inner_calls.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, String>(Response::new("ok".to_string())) }
        });

        let svc = RewriteUriService::new(inner, |_uri: &Uri| {
            Err::<Uri, String>("rewrite failed".to_string())
        });

        let result = svc.oneshot(request("/original")).await;
        assert_eq!(result.unwrap_err(), "rewrite failed");
        assert_eq!(calls.load(Ordering::SeqCst), 0, "inner service must not be called");
    }

    #[tokio::test]
    async fn test_rewrite_uri_struct_impl() {
        #[derive(Clone)]
        struct PrependBase {
            base: &'static str,
        }

        impl RewriteUri for PrependBase {
            type Error = Infallible;

            fn rewrite_uri(&mut self, uri: &Uri) -> Result<Uri, Self::Error> {
                let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
                Ok(format!("{}{path}", self.base).parse().unwrap())
            }
        }

        let svc = RewriteUriLayer::new(PrependBase {
            base: "http://backend.internal",
        })
        .layer(capture_uri_service());

        let response = svc.oneshot(request("/api/users")).await.unwrap();
        assert_eq!(response.into_body(), "http://backend.internal/api/users");
    }
}