Skip to main content

rust_web_server/rewrite/
mod.rs

1//! Request and response rewriting middleware.
2//!
3//! [`RewriteLayer`] is a [`Middleware`] that transforms requests before they
4//! reach handlers and responses before they leave the server. Build one with the
5//! fluent builder API and add it to any [`crate::middleware::WithMiddleware`] stack.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use rust_web_server::app::App;
11//! use rust_web_server::core::New;
12//! use rust_web_server::rewrite::RewriteLayer;
13//!
14//! let app = App::new()
15//!     .wrap(RewriteLayer::new()
16//!         .request_header_set("X-Env", "production")
17//!         .request_uri_strip_prefix("/api/v1")
18//!         .response_header_set("Cache-Control", "no-store")
19//!         .response_body_replace("http://staging.internal", "https://example.com"));
20//! ```
21//!
22//! # Regex URI rewriting (`rewrite-regex` feature)
23//!
24//! The prefix/set operations above cover fixed strings. When the rewrite
25//! depends on part of the incoming path — versioning schemes, locale
26//! prefixes, ID extraction — [`RewriteLayer::request_uri_regex_rewrite`]
//! matches the URI against a regex and rewrites it using the match's capture
//! groups, modeled on nginx's `rewrite` directive: if the pattern matches
//! anywhere in the URI, the **entire** URI is replaced by the expanded
//! replacement string; otherwise the URI is left untouched.
31//!
32//! ```rust,no_run
33//! # #[cfg(feature = "rewrite-regex")]
34//! # fn example() -> Result<(), regex::Error> {
35//! use rust_web_server::app::App;
36//! use rust_web_server::core::New;
37//! use rust_web_server::rewrite::RewriteLayer;
38//!
39//! let app = App::new()
40//!     .wrap(RewriteLayer::new()
41//!         .request_uri_regex_rewrite(r"^/api/v\d+/(.*)$", "/$1")?);
42//! # Ok(())
43//! # }
44//! ```
45//!
//! Requires the `rewrite-regex` feature (adds the `regex` crate). This is
//! the one place in `rws` where a third-party regex engine is worth the
//! dependency: hand-rolling a regex engine is out of scope, and the crate's
//! "no third-party HTTP dependencies" philosophy doesn't extend to
//! general-purpose text processing.
51
52#[cfg(test)]
53mod tests;
54
55use crate::application::Application;
56use crate::header::Header;
57use crate::middleware::Middleware;
58use crate::request::Request;
59use crate::response::Response;
60use crate::server::ConnectionInfo;
61
62#[cfg(feature = "rewrite-regex")]
63use regex::Regex;
64
/// One transformation applied to the outgoing copy of the request before it
/// is dispatched to the next handler.
enum RequestRule {
    /// Drop every header named `name` (ASCII case-insensitive), then append
    /// a single `name: value` header.
    SetHeader {
        name: String,
        value: String,
    },
    /// Drop every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Overwrite the whole request URI.
    SetUri(String),
    /// Remove a literal leading string from the URI, keeping a leading `/`.
    StripUriPrefix(String),
    /// Prepend a literal string to the URI.
    AddUriPrefix(String),
    /// Replace the whole URI with the expanded `replacement` template when
    /// `pattern` matches anywhere in it.
    #[cfg(feature = "rewrite-regex")]
    RewriteUri {
        pattern: Regex,
        replacement: String,
    },
}
74
/// One transformation applied to the response returned by the next handler.
enum ResponseRule {
    /// Drop every header named `name` (ASCII case-insensitive), then append
    /// a single `name: value` header.
    SetHeader {
        name: String,
        value: String,
    },
    /// Drop every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Overwrite the status code and reason phrase.
    SetStatus {
        code: i16,
        reason: String,
    },
    /// Replace every non-overlapping occurrence of `from` with `to` in each
    /// body content range.
    BodyReplace {
        from: Vec<u8>,
        to: Vec<u8>,
    },
}
81
82/// Composable request/response rewriting middleware.
83///
84/// Clones the incoming [`Request`], applies request rules, dispatches to the
85/// next handler, then applies response rules on the returned [`Response`].
86///
87/// All builder methods take `self` by value and return `Self` for chaining.
88pub struct RewriteLayer {
89    request_rules: Vec<RequestRule>,
90    response_rules: Vec<ResponseRule>,
91}
92
93impl RewriteLayer {
94    pub fn new() -> Self {
95        RewriteLayer { request_rules: Vec::new(), response_rules: Vec::new() }
96    }
97
98    /// Add or replace a request header (case-insensitive name match).
99    pub fn request_header_set(mut self, name: &str, value: &str) -> Self {
100        self.request_rules.push(RequestRule::SetHeader {
101            name: name.to_string(),
102            value: value.to_string(),
103        });
104        self
105    }
106
107    /// Remove a request header (case-insensitive).
108    pub fn request_header_remove(mut self, name: &str) -> Self {
109        self.request_rules.push(RequestRule::RemoveHeader(name.to_string()));
110        self
111    }
112
113    /// Replace the entire request URI.
114    pub fn request_uri_set(mut self, uri: &str) -> Self {
115        self.request_rules.push(RequestRule::SetUri(uri.to_string()));
116        self
117    }
118
119    /// Strip a path prefix from the request URI. No-op if the prefix is absent.
120    /// Normalizes to `"/"` if stripping leaves an empty path.
121    pub fn request_uri_strip_prefix(mut self, prefix: &str) -> Self {
122        self.request_rules.push(RequestRule::StripUriPrefix(prefix.to_string()));
123        self
124    }
125
126    /// Prepend a prefix to the request URI.
127    pub fn request_uri_add_prefix(mut self, prefix: &str) -> Self {
128        self.request_rules.push(RequestRule::AddUriPrefix(prefix.to_string()));
129        self
130    }
131
132    /// Rewrite the request URI by regex, nginx `rewrite`-directive style.
133    ///
134    /// If `pattern` matches anywhere in the URI, the **entire** URI is
135    /// replaced by `replacement` with capture-group references expanded —
136    /// `$1`, `$2`, ... for numbered groups, `${name}` for named groups
137    /// (`(?P<name>...)`), or `$0`/`${0}` for the whole match. If `pattern`
138    /// does not match, the URI is left unchanged.
139    ///
140    /// Returns `Err` if `pattern` is not a valid regex. Requires the
141    /// `rewrite-regex` feature.
142    #[cfg(feature = "rewrite-regex")]
143    pub fn request_uri_regex_rewrite(mut self, pattern: &str, replacement: &str) -> Result<Self, regex::Error> {
144        let compiled = Regex::new(pattern)?;
145        self.request_rules.push(RequestRule::RewriteUri {
146            pattern: compiled,
147            replacement: replacement.to_string(),
148        });
149        Ok(self)
150    }
151
152    /// Add or replace a response header (case-insensitive name match).
153    pub fn response_header_set(mut self, name: &str, value: &str) -> Self {
154        self.response_rules.push(ResponseRule::SetHeader {
155            name: name.to_string(),
156            value: value.to_string(),
157        });
158        self
159    }
160
161    /// Remove a response header (case-insensitive).
162    pub fn response_header_remove(mut self, name: &str) -> Self {
163        self.response_rules.push(ResponseRule::RemoveHeader(name.to_string()));
164        self
165    }
166
167    /// Override the response status code and reason phrase.
168    pub fn response_status(mut self, code: i16, reason: &str) -> Self {
169        self.response_rules.push(ResponseRule::SetStatus { code, reason: reason.to_string() });
170        self
171    }
172
173    /// Byte-level find-and-replace across all response body content ranges.
174    pub fn response_body_replace(mut self, from: &str, to: &str) -> Self {
175        self.response_rules.push(ResponseRule::BodyReplace {
176            from: from.as_bytes().to_vec(),
177            to: to.as_bytes().to_vec(),
178        });
179        self
180    }
181}
182
183impl Middleware for RewriteLayer {
184    fn handle(
185        &self,
186        request: &Request,
187        connection: &ConnectionInfo,
188        next: &dyn Application,
189    ) -> Result<Response, String> {
190        let mut req = request.clone();
191
192        for rule in &self.request_rules {
193            match rule {
194                RequestRule::SetHeader { name, value } => {
195                    req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
196                    req.headers.push(Header { name: name.clone(), value: value.clone() });
197                }
198                RequestRule::RemoveHeader(name) => {
199                    req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
200                }
201                RequestRule::SetUri(uri) => {
202                    req.request_uri = uri.clone();
203                }
204                RequestRule::StripUriPrefix(prefix) => {
205                    if let Some(stripped) = req.request_uri.strip_prefix(prefix.as_str()) {
206                        req.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
207                            format!("/{}", stripped)
208                        } else {
209                            stripped.to_string()
210                        };
211                    }
212                }
213                RequestRule::AddUriPrefix(prefix) => {
214                    req.request_uri = format!("{}{}", prefix, req.request_uri);
215                }
216                #[cfg(feature = "rewrite-regex")]
217                RequestRule::RewriteUri { pattern, replacement } => {
218                    if let Some(captures) = pattern.captures(&req.request_uri) {
219                        let mut expanded = String::new();
220                        captures.expand(replacement, &mut expanded);
221                        req.request_uri = expanded;
222                    }
223                }
224            }
225        }
226
227        let mut response = next.execute(&req, connection)?;
228
229        for rule in &self.response_rules {
230            match rule {
231                ResponseRule::SetHeader { name, value } => {
232                    response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
233                    response.headers.push(Header { name: name.clone(), value: value.clone() });
234                }
235                ResponseRule::RemoveHeader(name) => {
236                    response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
237                }
238                ResponseRule::SetStatus { code, reason } => {
239                    response.status_code = *code;
240                    response.reason_phrase = reason.clone();
241                }
242                ResponseRule::BodyReplace { from, to } => {
243                    for cr in &mut response.content_range_list {
244                        cr.body = replace_bytes(&cr.body, from, to);
245                    }
246                }
247            }
248        }
249
250        Ok(response)
251    }
252}
253
/// Returns a copy of `haystack` in which every leftmost, non-overlapping
/// occurrence of `needle` is replaced by `replacement`.
///
/// Scanning resumes immediately after each match, so bytes introduced by
/// `replacement` are never rescanned. An empty `needle` yields an unchanged
/// copy.
fn replace_bytes(haystack: &[u8], needle: &[u8], replacement: &[u8]) -> Vec<u8> {
    let width = needle.len();
    if width == 0 {
        return haystack.to_vec();
    }

    let mut output = Vec::with_capacity(haystack.len());
    let mut remaining = haystack;
    // Copy each unmatched run in one go, then the replacement, then continue
    // with whatever follows the match.
    while let Some(hit) = remaining.windows(width).position(|window| window == needle) {
        output.extend_from_slice(&remaining[..hit]);
        output.extend_from_slice(replacement);
        remaining = &remaining[hit + width..];
    }
    output.extend_from_slice(remaining);
    output
}