static_web_server/
rewrites.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Module that allows rewriting request URLs with pattern matching support.
7//!
8
9use headers::HeaderValue;
10use hyper::{Body, Request, Response, StatusCode, Uri, header::HOST};
11
12use crate::{
13    Error,
14    handler::RequestHandlerOpts,
15    redirects::{handle_error, replace_placeholders},
16    settings::{Rewrites, file::RedirectsKind},
17};
18
19/// Applies rewrite rules to a request if necessary.
20pub(crate) fn pre_process<T>(
21    opts: &RequestHandlerOpts,
22    req: &mut Request<T>,
23) -> Option<Result<Response<Body>, Error>> {
24    let rewrites = opts.advanced_opts.as_ref()?.rewrites.as_deref()?;
25    let uri_path = req.uri().path();
26
27    let matched = rewrite_uri_path(uri_path, Some(rewrites))?;
28    let dest = match replace_placeholders(uri_path, &matched.source, &matched.destination) {
29        Ok(dest) => dest,
30        Err(err) => return handle_error(err, opts, req),
31    };
32
33    if let Some(redirect_type) = &matched.redirect {
34        // Handle redirects
35        let loc = match HeaderValue::from_str(&dest) {
36            Ok(val) => val,
37            Err(err) => {
38                return handle_error(
39                    Error::new(err).context("invalid header value from current uri"),
40                    opts,
41                    req,
42                );
43            }
44        };
45        let mut resp = Response::new(Body::empty());
46        resp.headers_mut().insert(hyper::header::LOCATION, loc);
47        *resp.status_mut() = match redirect_type {
48            RedirectsKind::Permanent => StatusCode::MOVED_PERMANENTLY,
49            RedirectsKind::Temporary => StatusCode::FOUND,
50        };
51        Some(Ok(resp))
52    } else {
53        // Handle internal rewrites
54        *req.uri_mut() = match merge_uris(req.uri(), &dest) {
55            Ok(uri) => uri,
56            Err(err) => {
57                return handle_error(
58                    err.context("invalid rewrite target from current uri"),
59                    opts,
60                    req,
61                );
62            }
63        };
64
65        // Adjust Host header to allow rewriting to a different virtual host
66        if let Some(host) = req.uri().host() {
67            let mut host = host.to_owned();
68            if let Some(port) = req.uri().port_u16() {
69                host.push_str(&format!(":{port}"));
70            }
71            if let Ok(host) = host.parse() {
72                req.headers_mut().insert(HOST, host);
73            }
74        }
75
76        None
77    }
78}
79
80fn merge_uris(orig_uri: &Uri, new_uri: &str) -> Result<Uri, Error> {
81    let mut parts = new_uri.parse::<Uri>()?.into_parts();
82    if parts.scheme.is_none() {
83        parts.scheme = orig_uri.scheme().cloned();
84    }
85    if parts.authority.is_none() {
86        parts.authority = orig_uri.authority().cloned();
87    }
88    if parts.path_and_query.is_none() {
89        parts.path_and_query = orig_uri.path_and_query().cloned();
90    }
91    if let Some(path_and_query) = &mut parts.path_and_query {
92        if let (None, Some(query)) = (path_and_query.query(), orig_uri.query()) {
93            *path_and_query = [path_and_query.as_str(), "?", query]
94                .into_iter()
95                .collect::<String>()
96                .parse()?;
97        }
98    }
99    Ok(Uri::from_parts(parts)?)
100}
101
102/// It returns a rewrite's destination path if the current request uri
103/// matches against the provided rewrites array.
104pub fn rewrite_uri_path<'a>(
105    uri_path: &'a str,
106    rewrites_opts: Option<&'a [Rewrites]>,
107) -> Option<&'a Rewrites> {
108    if let Some(rewrites_vec) = rewrites_opts {
109        for rewrites_entry in rewrites_vec {
110            // Match source glob pattern against request uri path
111            if rewrites_entry.source.is_match(uri_path) {
112                return Some(rewrites_entry);
113            }
114        }
115    }
116
117    None
118}
119
#[cfg(test)]
mod tests {
    use super::pre_process;
    use crate::{
        Error,
        handler::RequestHandlerOpts,
        settings::{Advanced, Rewrites, file::RedirectsKind},
    };
    use hyper::{Body, Request, Response, StatusCode, header::HOST};
    use regex_lite::Regex;

    /// Builds a GET request for `uri`, adding a `Host` header unless `host` is empty.
    fn make_request(host: &str, uri: &str) -> Request<Body> {
        let builder = Request::builder().method("GET").uri(uri);
        let builder = if host.is_empty() {
            builder
        } else {
            builder.header("Host", host)
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Handler options whose advanced settings carry the given rewrite rules.
    fn opts_with(rewrites: Option<Vec<Rewrites>>) -> RequestHandlerOpts {
        RequestHandlerOpts {
            advanced_opts: Some(Advanced {
                rewrites,
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Rule set shared by the tests: an internal rewrite, a temporary
    /// redirect, a permanent redirect with captures, and an internal rewrite
    /// to an absolute URI on another host.
    fn get_rewrites() -> Vec<Rewrites> {
        let rule = |pattern: &str, destination: &str, redirect| Rewrites {
            source: Regex::new(pattern).unwrap(),
            destination: destination.into(),
            redirect,
        };
        vec![
            rule(r"/source1$", "/destination1", None),
            rule(r"/source2$", "/destination2", Some(RedirectsKind::Temporary)),
            rule(
                r"/(prefix/)?(source3)/(.*)",
                "/destination3/$2/$3",
                Some(RedirectsKind::Permanent),
            ),
            rule(
                r"/(source4)/(.*)",
                "http://example.net:1234/destination4/$1?$2",
                None,
            ),
        ]
    }

    /// Extracts `(status, Location)` from a successful response that has a
    /// `Location` header; yields `None` otherwise.
    fn is_redirect(result: Option<Result<Response<Body>, Error>>) -> Option<(StatusCode, String)> {
        match result {
            Some(Ok(response)) => {
                let location = response.headers().get("Location")?.to_str().unwrap();
                Some((response.status(), location.into()))
            }
            _ => None,
        }
    }

    #[test]
    fn test_no_rewrites() {
        let without_advanced = RequestHandlerOpts {
            advanced_opts: None,
            ..Default::default()
        };
        let mut req = make_request("", "/");
        assert!(pre_process(&without_advanced, &mut req).is_none());
        assert_eq!(req.uri(), "/");

        let mut req = make_request("", "/");
        assert!(pre_process(&opts_with(None), &mut req).is_none());
        assert_eq!(req.uri(), "/");
    }

    #[test]
    fn test_no_match() {
        let opts = opts_with(Some(get_rewrites()));

        let mut req = make_request("example.com", "/source2/whatever");
        assert!(pre_process(&opts, &mut req).is_none());
        assert_eq!(req.uri(), "/source2/whatever");
    }

    #[test]
    fn test_match() {
        let opts = opts_with(Some(get_rewrites()));

        // Internal rewrite keeps the original query string.
        let mut req = make_request("", "/source1?query");
        assert!(pre_process(&opts, &mut req).is_none());
        assert_eq!(req.uri(), "/destination1?query");

        // Temporary redirect.
        let mut req = make_request("", "/source2");
        assert_eq!(
            is_redirect(pre_process(&opts, &mut req)),
            Some((StatusCode::FOUND, "/destination2".into()))
        );

        // Permanent redirect with capture-group placeholders.
        let mut req = make_request("", "/source3/whatever");
        assert_eq!(
            is_redirect(pre_process(&opts, &mut req)),
            Some((
                StatusCode::MOVED_PERMANENTLY,
                "/destination3/source3/whatever".into()
            ))
        );

        // Internal rewrite to another virtual host updates the Host header.
        let mut req = make_request("example.com", "/source4/whatever?query");
        assert!(pre_process(&opts, &mut req).is_none());
        assert_eq!(
            req.uri(),
            "http://example.net:1234/destination4/source4?whatever"
        );
        let host = req
            .headers()
            .get(HOST)
            .map(|h| h.to_str().unwrap())
            .unwrap_or("");
        assert_eq!(host, "example.net:1234");
    }
}