tiny_proxy/proxy/
handler.rs1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14use crate::config::Config;
15use crate::proxy::ActionResult;
16
17use crate::proxy::directives::{
18 handle_header, handle_method, handle_respond, handle_reverse_proxy, handle_uri_replace,
19};
20
21type ResponseBody =
24 http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
25
26fn is_hop_header(name: &header::HeaderName) -> bool {
32 matches!(
33 name,
34 &header::CONNECTION
35 | &header::UPGRADE
36 | &header::TE
37 | &header::TRAILER
38 | &header::PROXY_AUTHENTICATE
39 | &header::PROXY_AUTHORIZATION
40 )
41}
42
43pub fn process_directives(
46 directives: &[crate::config::Directive],
47 req: &mut Request<Incoming>,
48 current_path: &str,
49) -> Result<ActionResult, String> {
50 let mut modified_path = current_path.to_string();
51
52 for directive in directives {
53 match directive {
54 crate::config::Directive::Header { name, value } => {
58 if let Err(e) = handle_header(name, value, req) {
59 info!(" Failed to apply header {}: {}", name, e);
60 }
61 }
62
63 crate::config::Directive::UriReplace { find, replace } => {
65 handle_uri_replace(find, replace, &mut modified_path);
66 }
67
68 crate::config::Directive::HandlePath {
70 pattern,
71 directives: nested_directives,
72 } => {
73 if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
74 info!(" Matched handle_path: {}", pattern);
75 return process_directives(nested_directives, req, &remaining_path);
77 }
78 }
79
80 crate::config::Directive::Method {
82 methods,
83 directives: nested_directives,
84 } => {
85 if handle_method(methods, req) {
86 info!(" Matched method directive");
87 return process_directives(nested_directives, req, &modified_path);
89 }
90 }
91
92 crate::config::Directive::Respond { status, body } => {
94 return Ok(handle_respond(status, body));
95 }
96
97 crate::config::Directive::ReverseProxy { to } => {
99 return Ok(handle_reverse_proxy(to, &modified_path));
100 }
101 }
102 }
103
104 Err(format!(
105 "No action directive (respond or reverse_proxy) found in configuration for path: {}",
106 current_path
107 ))
108}
109
110pub async fn proxy(
121 mut req: Request<Incoming>,
122 client: Client<HttpsConnector<HttpConnector>, Incoming>,
123 config: Arc<Config>,
124) -> Result<Response<ResponseBody>, Error> {
125 let path = req.uri().path().to_string();
127
128 let host = req
130 .headers()
131 .get(hyper::header::HOST)
132 .and_then(|h| h.to_str().ok())
133 .unwrap_or("localhost");
134
135 if tracing::enabled!(tracing::Level::INFO) {
137 }
140
141 let site_config = match config.sites.get(host) {
143 Some(config) => config,
144 None => {
145 error!("No configuration found for host: {}", host);
146 return Ok(error_response(
147 StatusCode::NOT_FOUND,
148 &format!("No configuration found for host: {}", host),
149 ));
150 }
151 };
152
153 let action_result =
155 process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;
156
157 match action_result {
159 ActionResult::Respond { status, body } => {
160 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
161 let boxed: ResponseBody = Full::new(Bytes::from(body))
162 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
163 .boxed();
164 Ok(Response::builder().status(status_code).body(boxed).unwrap())
165 }
166 ActionResult::ReverseProxy {
167 backend_url,
168 path_to_send,
169 } => {
170 let backend_with_proto =
172 if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
173 backend_url
174 } else {
175 format!("http://{}", backend_url)
176 };
177
178 let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
180 parts.path_and_query = Some(path_to_send.parse()?);
181 let new_uri = Uri::from_parts(parts)?;
182
183 if tracing::enabled!(tracing::Level::INFO) {
185 }
188
189 *req.uri_mut() = new_uri.clone();
190
191 let original_host_header = req.headers().get(hyper::header::HOST).cloned();
194
195 req.headers_mut().remove(hyper::header::HOST);
197 if let Some(authority) = new_uri.authority() {
198 if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
199 req.headers_mut().insert(hyper::header::HOST, host_value);
200 }
201 }
202
203 if let Some(host_value) = original_host_header.clone() {
206 req.headers_mut().insert("X-Forwarded-Host", host_value);
207 }
208
209 let original_scheme = req.uri().scheme_str().unwrap_or("http");
211 match original_scheme {
213 "http" => {
214 req.headers_mut().insert(
215 "X-Forwarded-Proto",
216 hyper::header::HeaderValue::from_static("http"),
217 );
218 }
219 "https" => {
220 req.headers_mut().insert(
221 "X-Forwarded-Proto",
222 hyper::header::HeaderValue::from_static("https"),
223 );
224 }
225 _ => {} }
227
228 if let Some(for_value) = original_host_header {
232 req.headers_mut().insert("X-Forwarded-For", for_value);
233 }
234
235 req.headers_mut().remove(header::CONNECTION);
238
239 req.headers_mut().remove("accept-encoding");
242
243 match timeout(Duration::from_secs(30), client.request(req)).await {
245 Ok(Ok(response)) => {
246 let status = response.status();
248 let headers = response.headers().clone();
249
250 if tracing::enabled!(tracing::Level::INFO) {
252 }
255
256 let mut builder = Response::builder().status(status);
258
259 for (name, value) in headers.iter() {
263 if !is_hop_header(name) && name != header::CONTENT_LENGTH {
264 builder = builder.header(name, value);
265 }
266 }
267
268 let (_, incoming_body) = response.into_parts();
270 let boxed: ResponseBody = incoming_body
271 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
272 .boxed();
273
274 Ok(builder.body(boxed).unwrap())
275 }
276 Ok(Err(e)) => {
277 error!("Backend connection failed: {:?}", e);
279
280 if e.is_connect() {
281 error!(" Reason: Connection refused - backend unavailable");
282 } else {
283 error!(" Reason: Other connection error");
284 }
285
286 Ok(error_response(
287 StatusCode::BAD_GATEWAY,
288 "Backend service unavailable",
289 ))
290 }
291 Err(_) => {
292 error!("Backend request timed out after 30 seconds");
294
295 Ok(error_response(
296 StatusCode::GATEWAY_TIMEOUT,
297 "Backend request timed out",
298 ))
299 }
300 }
301 }
302 }
303}
304
305fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
307 let body = format!(
308 r#"<!DOCTYPE html>
309 <html>
310 <head><title>{} {}</title></head>
311 <body>
312 <h1>{} {}</h1>
313 <p>{}</p>
314 <hr>
315 <p><em>Rust Proxy Server</em></p>
316 </body>
317 </html>"#,
318 status.as_u16(),
319 status.canonical_reason().unwrap_or("Error"),
320 status.as_u16(),
321 status.canonical_reason().unwrap_or("Error"),
322 message
323 );
324
325 let full = Full::new(Bytes::from(body));
326 let boxed: ResponseBody = full
327 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
328 .boxed();
329
330 Response::builder()
331 .status(status)
332 .header("Content-Type", "text/html; charset=utf-8")
333 .body(boxed)
334 .unwrap()
335}
336
/// Matches a `handle_path` pattern against a request path and returns the
/// path that nested directives should see.
///
/// * `"<prefix>/*"` matches `<prefix>` itself and anything below it on a
///   segment boundary (`<prefix>/...`), but not siblings such as
///   `<prefix>ed`. The prefix is stripped, and an empty remainder becomes `"/"`.
/// * Any other pattern must equal `path` exactly; the remainder is then `"/"`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    if let Some(prefix) = pattern.strip_suffix("/*") {
        let rest = path.strip_prefix(prefix)?;
        if rest.is_empty() {
            Some("/".to_string())
        } else if rest.starts_with('/') {
            Some(rest.to_string())
        } else {
            // Not a segment boundary: `/api/*` must not match `/apiary`.
            None
        }
    } else if pattern == path {
        Some("/".to_string())
    } else {
        None
    }
}