Skip to main content

rust_web_server/rewrite/
mod.rs

1//! Request and response rewriting middleware.
2//!
3//! [`RewriteLayer`] is a [`Middleware`] that transforms requests before they
4//! reach handlers and responses before they leave the server. Build one with the
5//! fluent builder API and add it to any [`crate::middleware::WithMiddleware`] stack.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use rust_web_server::app::App;
11//! use rust_web_server::core::New;
12//! use rust_web_server::rewrite::RewriteLayer;
13//!
14//! let app = App::new()
15//!     .wrap(RewriteLayer::new()
16//!         .request_header_set("X-Env", "production")
17//!         .request_uri_strip_prefix("/api/v1")
18//!         .response_header_set("Cache-Control", "no-store")
19//!         .response_body_replace("http://staging.internal", "https://example.com"));
20//! ```
21
22#[cfg(test)]
23mod tests;
24
25use crate::application::Application;
26use crate::header::Header;
27use crate::middleware::Middleware;
28use crate::request::Request;
29use crate::response::Response;
30use crate::server::ConnectionInfo;
31
/// A single transformation applied to the incoming request before it is
/// dispatched to the next handler.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RequestRule {
    /// Replace every header whose name matches `name` (ASCII
    /// case-insensitive) with a single header carrying `value`.
    SetHeader { name: String, value: String },
    /// Drop every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Overwrite the whole request URI with this value.
    SetUri(String),
    /// Remove this prefix from the start of the request URI when it matches.
    StripUriPrefix(String),
    /// Prepend this string to the request URI.
    AddUriPrefix(String),
}
39
/// A single transformation applied to the response returned by the next
/// handler.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ResponseRule {
    /// Replace every header whose name matches `name` (ASCII
    /// case-insensitive) with a single header carrying `value`.
    SetHeader { name: String, value: String },
    /// Drop every header with this name (ASCII case-insensitive).
    RemoveHeader(String),
    /// Overwrite the status code and reason phrase.
    SetStatus { code: i16, reason: String },
    /// Replace every occurrence of `from` with `to` in each body content range.
    BodyReplace { from: Vec<u8>, to: Vec<u8> },
}
46
47/// Composable request/response rewriting middleware.
48///
49/// Clones the incoming [`Request`], applies request rules, dispatches to the
50/// next handler, then applies response rules on the returned [`Response`].
51///
52/// All builder methods take `self` by value and return `Self` for chaining.
53pub struct RewriteLayer {
54    request_rules: Vec<RequestRule>,
55    response_rules: Vec<ResponseRule>,
56}
57
58impl RewriteLayer {
59    pub fn new() -> Self {
60        RewriteLayer { request_rules: Vec::new(), response_rules: Vec::new() }
61    }
62
63    /// Add or replace a request header (case-insensitive name match).
64    pub fn request_header_set(mut self, name: &str, value: &str) -> Self {
65        self.request_rules.push(RequestRule::SetHeader {
66            name: name.to_string(),
67            value: value.to_string(),
68        });
69        self
70    }
71
72    /// Remove a request header (case-insensitive).
73    pub fn request_header_remove(mut self, name: &str) -> Self {
74        self.request_rules.push(RequestRule::RemoveHeader(name.to_string()));
75        self
76    }
77
78    /// Replace the entire request URI.
79    pub fn request_uri_set(mut self, uri: &str) -> Self {
80        self.request_rules.push(RequestRule::SetUri(uri.to_string()));
81        self
82    }
83
84    /// Strip a path prefix from the request URI. No-op if the prefix is absent.
85    /// Normalizes to `"/"` if stripping leaves an empty path.
86    pub fn request_uri_strip_prefix(mut self, prefix: &str) -> Self {
87        self.request_rules.push(RequestRule::StripUriPrefix(prefix.to_string()));
88        self
89    }
90
91    /// Prepend a prefix to the request URI.
92    pub fn request_uri_add_prefix(mut self, prefix: &str) -> Self {
93        self.request_rules.push(RequestRule::AddUriPrefix(prefix.to_string()));
94        self
95    }
96
97    /// Add or replace a response header (case-insensitive name match).
98    pub fn response_header_set(mut self, name: &str, value: &str) -> Self {
99        self.response_rules.push(ResponseRule::SetHeader {
100            name: name.to_string(),
101            value: value.to_string(),
102        });
103        self
104    }
105
106    /// Remove a response header (case-insensitive).
107    pub fn response_header_remove(mut self, name: &str) -> Self {
108        self.response_rules.push(ResponseRule::RemoveHeader(name.to_string()));
109        self
110    }
111
112    /// Override the response status code and reason phrase.
113    pub fn response_status(mut self, code: i16, reason: &str) -> Self {
114        self.response_rules.push(ResponseRule::SetStatus { code, reason: reason.to_string() });
115        self
116    }
117
118    /// Byte-level find-and-replace across all response body content ranges.
119    pub fn response_body_replace(mut self, from: &str, to: &str) -> Self {
120        self.response_rules.push(ResponseRule::BodyReplace {
121            from: from.as_bytes().to_vec(),
122            to: to.as_bytes().to_vec(),
123        });
124        self
125    }
126}
127
128impl Middleware for RewriteLayer {
129    fn handle(
130        &self,
131        request: &Request,
132        connection: &ConnectionInfo,
133        next: &dyn Application,
134    ) -> Result<Response, String> {
135        let mut req = request.clone();
136
137        for rule in &self.request_rules {
138            match rule {
139                RequestRule::SetHeader { name, value } => {
140                    req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
141                    req.headers.push(Header { name: name.clone(), value: value.clone() });
142                }
143                RequestRule::RemoveHeader(name) => {
144                    req.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
145                }
146                RequestRule::SetUri(uri) => {
147                    req.request_uri = uri.clone();
148                }
149                RequestRule::StripUriPrefix(prefix) => {
150                    if let Some(stripped) = req.request_uri.strip_prefix(prefix.as_str()) {
151                        req.request_uri = if stripped.is_empty() || !stripped.starts_with('/') {
152                            format!("/{}", stripped)
153                        } else {
154                            stripped.to_string()
155                        };
156                    }
157                }
158                RequestRule::AddUriPrefix(prefix) => {
159                    req.request_uri = format!("{}{}", prefix, req.request_uri);
160                }
161            }
162        }
163
164        let mut response = next.execute(&req, connection)?;
165
166        for rule in &self.response_rules {
167            match rule {
168                ResponseRule::SetHeader { name, value } => {
169                    response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
170                    response.headers.push(Header { name: name.clone(), value: value.clone() });
171                }
172                ResponseRule::RemoveHeader(name) => {
173                    response.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
174                }
175                ResponseRule::SetStatus { code, reason } => {
176                    response.status_code = *code;
177                    response.reason_phrase = reason.clone();
178                }
179                ResponseRule::BodyReplace { from, to } => {
180                    for cr in &mut response.content_range_list {
181                        cr.body = replace_bytes(&cr.body, from, to);
182                    }
183                }
184            }
185        }
186
187        Ok(response)
188    }
189}
190
/// Returns a copy of `haystack` with every occurrence of `needle` swapped for
/// `replacement`.
///
/// Occurrences are located left to right and never overlap: after a match the
/// search resumes just past it. An empty `needle` yields an unchanged copy.
fn replace_bytes(haystack: &[u8], needle: &[u8], replacement: &[u8]) -> Vec<u8> {
    let width = needle.len();
    if width == 0 {
        return haystack.to_vec();
    }

    let mut output = Vec::with_capacity(haystack.len());
    let mut remaining = haystack;
    // `windows` yields nothing once `remaining` is shorter than `needle`,
    // which ends the loop and leaves the tail to be copied verbatim below.
    while let Some(offset) = remaining.windows(width).position(|window| window == needle) {
        let (before, after) = remaining.split_at(offset);
        output.extend_from_slice(before);
        output.extend_from_slice(replacement);
        remaining = &after[width..];
    }
    output.extend_from_slice(remaining);
    output
}
208}