tower_http_client/
rewrite_uri.rs1#![doc = include_str!("../examples/rewrite_uri.rs")]
30use std::task::{Context, Poll};
33
34use futures_util::future::{Either, Ready, ready};
35use tower_layer::Layer;
36use tower_service::Service;
37
38pub trait RewriteUri {
44 type Error;
46
47 fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error>;
49}
50
51impl<F, E> RewriteUri for F
52where
53 F: FnMut(&http::Uri) -> Result<http::Uri, E>,
54{
55 type Error = E;
56
57 fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error> {
58 self(uri)
59 }
60}
61
/// [`Layer`] that wraps services in a [`RewriteUriService`].
///
/// The layer holds one rewriter and hands a clone of it to every service it
/// produces.
#[derive(Debug, Clone)]
pub struct RewriteUriLayer<R> {
    /// Rewriter cloned into each service built by this layer.
    rewrite: R,
}

impl<R> RewriteUriLayer<R> {
    /// Creates a layer whose services rewrite request URIs with `rewrite`.
    pub fn new(rewrite: R) -> Self {
        RewriteUriLayer { rewrite }
    }
}
77
78impl<S, R: Clone> Layer<S> for RewriteUriLayer<R> {
79 type Service = RewriteUriService<S, R>;
80
81 fn layer(&self, inner: S) -> Self::Service {
82 RewriteUriService::new(inner, self.rewrite.clone())
83 }
84}
85
/// Middleware that rewrites the URI of every request before forwarding it to
/// the wrapped service.
///
/// When the rewriter succeeds, the request's URI is replaced and the request is
/// passed to `inner`. When it fails, `inner` is not called and the rewriter's
/// error is returned, converted into `inner`'s error type.
#[derive(Debug, Clone)]
pub struct RewriteUriService<S, R> {
    /// Wrapped service that receives the rewritten request.
    inner: S,
    /// Rewriter applied to each request URI.
    rewrite: R,
}

impl<S, R> RewriteUriService<S, R> {
    /// Wraps `inner`, rewriting each request's URI with `rewrite` first.
    pub fn new(inner: S, rewrite: R) -> Self {
        Self { inner, rewrite }
    }

    /// Returns a shared reference to the wrapped service.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the middleware and returns the wrapped service, discarding the
    /// rewriter.
    pub fn into_inner(self) -> S {
        self.inner
    }
}
99
100impl<S, R, ReqBody> Service<http::Request<ReqBody>> for RewriteUriService<S, R>
101where
102 S: Service<http::Request<ReqBody>>,
103 R: RewriteUri,
104 R::Error: Into<S::Error>,
105{
106 type Response = S::Response;
107 type Error = S::Error;
108 type Future = Either<Ready<Result<S::Response, S::Error>>, S::Future>;
109
110 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111 self.inner.poll_ready(cx)
112 }
113
114 fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
115 match self.rewrite.rewrite_uri(req.uri()) {
116 Ok(new_uri) => {
117 *req.uri_mut() = new_uri;
118 Either::Right(self.inner.call(req))
119 }
120 Err(e) => Either::Left(ready(Err(e.into()))),
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use http::{Request, Response, Uri};
    use tower::{ServiceBuilder, ServiceExt as _, service_fn};
    use tower_layer::Layer as _;

    use super::{RewriteUri, RewriteUriLayer, RewriteUriService};

    /// Inner service that echoes the URI it received as the response body, so
    /// tests can observe the URI after rewriting.
    fn capture_uri_service()
    -> impl tower_service::Service<Request<()>, Response = Response<String>, Error = Infallible>
    {
        service_fn(|req: Request<()>| async move {
            Ok::<_, Infallible>(Response::new(req.uri().to_string()))
        })
    }

    // All tests drive the service with `oneshot`, which awaits `poll_ready`
    // before `call` as the `Service` contract requires.

    #[tokio::test]
    async fn test_rewrite_uri_with_closure() {
        let svc = RewriteUriService::new(capture_uri_service(), |_uri: &Uri| {
            Ok::<_, Infallible>(Uri::from_static("http://example.com/rewritten"))
        });

        let response = svc
            .oneshot(Request::builder().uri("/original").body(()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.into_body(), "http://example.com/rewritten");
    }

    #[tokio::test]
    async fn test_rewrite_uri_layer() {
        let svc = RewriteUriLayer::new(|_uri: &Uri| {
            Ok::<_, Infallible>(Uri::from_static("http://example.com/via-layer"))
        })
        .layer(capture_uri_service());

        let req = Request::builder().uri("/original").body(()).unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert_eq!(response.into_body(), "http://example.com/via-layer");
    }

    #[tokio::test]
    async fn test_rewrite_uri_service_builder() {
        let svc = ServiceBuilder::new()
            .layer(RewriteUriLayer::new(|uri: &Uri| {
                let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
                let new_uri: Uri = format!("http://example.com{path}").parse().unwrap();
                Ok::<_, Infallible>(new_uri)
            }))
            .service(capture_uri_service());

        let req = Request::builder().uri("/hello").body(()).unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert_eq!(response.into_body(), "http://example.com/hello");
    }

    #[tokio::test]
    async fn test_rewrite_uri_error_propagates() {
        // Records whether the inner service was reached; a failed rewrite must
        // short-circuit before the inner service is called.
        let inner_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&inner_called);
        let inner = service_fn(move |_: Request<()>| {
            flag.store(true, Ordering::SeqCst);
            async { Ok::<_, String>(Response::new("ok".to_string())) }
        });

        let svc = RewriteUriService::new(inner, |_uri: &Uri| {
            Err::<Uri, String>("rewrite failed".to_string())
        });

        let req = Request::builder().uri("/original").body(()).unwrap();
        let result = svc.oneshot(req).await;
        assert_eq!(result.unwrap_err(), "rewrite failed");
        assert!(
            !inner_called.load(Ordering::SeqCst),
            "inner service must not be called when the rewrite fails"
        );
    }

    #[tokio::test]
    async fn test_rewrite_uri_struct_impl() {
        #[derive(Clone)]
        struct PrependBase {
            base: &'static str,
        }

        impl RewriteUri for PrependBase {
            type Error = Infallible;

            fn rewrite_uri(&mut self, uri: &Uri) -> Result<Uri, Self::Error> {
                let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
                Ok(format!("{}{path}", self.base).parse().unwrap())
            }
        }

        let svc = RewriteUriLayer::new(PrependBase {
            base: "http://backend.internal",
        })
        .layer(capture_uri_service());

        let req = Request::builder().uri("/api/users").body(()).unwrap();
        let response = svc.oneshot(req).await.unwrap();
        assert_eq!(response.into_body(), "http://backend.internal/api/users");
    }
}