Skip to main content

tiny_proxy/proxy/
handler.rs

1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14use crate::config::Config;
15use crate::proxy::ActionResult;
16
17use crate::proxy::directives::{
18    handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
19    handle_strip_prefix, handle_uri_replace,
20};
21
22/// Unified response body type - can handle both streaming (Incoming) and buffered (Full<Bytes>)
23/// This allows us to support SSE streaming while maintaining a simple API
24type ResponseBody =
25    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
26
27/// Check if header is hop-by-hop (should not be proxied)
28///
29/// Hop-by-hop headers are defined in RFC 7230 Section 6.1
30/// These headers are meant for a single connection and should NOT be proxied
31/// Uses hyper::header constants for optimal performance (no allocations!)
32fn is_hop_header(name: &header::HeaderName) -> bool {
33    matches!(
34        name,
35        &header::CONNECTION
36            | &header::UPGRADE
37            | &header::TE
38            | &header::TRAILER
39            | &header::PROXY_AUTHENTICATE
40            | &header::PROXY_AUTHORIZATION
41    )
42}
43
44/// Process directives in order, applying modifications and returning final action
45/// Supports recursive handling of handle_path blocks
46pub fn process_directives(
47    directives: &[crate::config::Directive],
48    req: &mut Request<Incoming>,
49    current_path: &str,
50) -> Result<ActionResult, String> {
51    let mut modified_path = current_path.to_string();
52
53    for directive in directives {
54        match directive {
55            // Apply header modifications using directive handler
56
57            // Apply header modifications using directive handler
58            crate::config::Directive::Header { name, value } => {
59                if let Err(e) = handle_header(name, value.as_deref(), req) {
60                    info!("   Failed to apply header {}: {}", name, e);
61                }
62            }
63
64            // Apply URI replacements using directive handler
65            crate::config::Directive::UriReplace { find, replace } => {
66                handle_uri_replace(find, replace, &mut modified_path);
67            }
68
69            // Strip prefix from URI path
70            crate::config::Directive::StripPrefix { prefix } => {
71                handle_strip_prefix(prefix, &mut modified_path);
72            }
73
74            // Handle path-based routing recursively
75            crate::config::Directive::HandlePath {
76                pattern,
77                directives: nested_directives,
78            } => {
79                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
80                    info!("   Matched handle_path: {}", pattern);
81                    // Recursively process nested directives with remaining path
82                    return process_directives(nested_directives, req, &remaining_path);
83                }
84            }
85
86            // Method-based directives
87            crate::config::Directive::Method {
88                methods,
89                directives: nested_directives,
90            } => {
91                if handle_method(methods, req) {
92                    info!("   Matched method directive");
93                    // Process nested directives with same path
94                    return process_directives(nested_directives, req, &modified_path);
95                }
96            }
97
98            // Redirect - return redirect response with Location header
99            crate::config::Directive::Redirect { status, url } => {
100                return Ok(handle_redirect(status, url));
101            }
102
103            // Direct response - return immediately using directive handler
104            crate::config::Directive::Respond { status, body } => {
105                return Ok(handle_respond(status, body));
106            }
107
108            // Reverse proxy - return action using directive handler
109            crate::config::Directive::ReverseProxy {
110                to,
111                connect_timeout,
112                read_timeout,
113            } => {
114                return Ok(handle_reverse_proxy(
115                    to,
116                    &modified_path,
117                    *connect_timeout,
118                    *read_timeout,
119                ));
120            }
121        }
122    }
123
124    Err(format!(
125        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
126        current_path
127    ))
128}
129
130/// Process a single request through the proxy
131///
132/// This implementation ALWAYS streams backend responses (nginx-style):
133/// - No buffering of response body
134/// - Direct streaming from backend to client
135/// - Works for both SSE and regular HTTP
136/// - Optimal performance and memory usage
137///
138/// For direct responses (Respond directive) and errors, buffering is used
139/// since these are small and generated by the proxy itself
140pub async fn proxy(
141    mut req: Request<Incoming>,
142    client: Client<HttpsConnector<HttpConnector>, Incoming>,
143    config: Arc<Config>,
144    remote_addr: std::net::SocketAddr,
145) -> Result<Response<ResponseBody>, Error> {
146    // Get path from URI (using String to avoid borrow conflict with mutable req)
147    let path = req.uri().path().to_string();
148
149    // Get host from Host header (includes port, e.g., "localhost:8080")
150    let host = req
151        .headers()
152        .get(hyper::header::HOST)
153        .and_then(|h| h.to_str().ok())
154        .unwrap_or("localhost");
155
156    // Logging with enabled check to avoid string formatting when disabled
157    if tracing::enabled!(tracing::Level::INFO) {
158        // Removed info logging from hot path for performance
159        // Use DEBUG level if needed for troubleshooting
160    }
161
162    // Find site configuration by host (with port!)
163    let site_config = match config.sites.get(host) {
164        Some(config) => config,
165        None => {
166            error!("No configuration found for host: {}", host);
167            return Ok(error_response(
168                StatusCode::NOT_FOUND,
169                &format!("No configuration found for host: {}", host),
170            ));
171        }
172    };
173
174    // Process directives in correct order
175    let action_result =
176        process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;
177
178    // Execute action
179    match action_result {
180        ActionResult::Redirect { status, url } => {
181            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
182            let boxed: ResponseBody = Full::new(Bytes::from(url.clone()))
183                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
184                .boxed();
185            Ok(Response::builder()
186                .status(status_code)
187                .header("Location", &url)
188                .body(boxed)?)
189        }
190        ActionResult::Respond { status, body } => {
191            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
192            let boxed: ResponseBody = Full::new(Bytes::from(body))
193                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
194                .boxed();
195            Ok(Response::builder().status(status_code).body(boxed)?)
196        }
197        ActionResult::ReverseProxy {
198            backend_url,
199            path_to_send,
200            connect_timeout: _,
201            read_timeout,
202        } => {
203            // Add protocol if missing
204            let backend_with_proto =
205                if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
206                    backend_url
207                } else {
208                    format!("http://{}", backend_url)
209                };
210
211            // Use Uri::from_parts() instead of format!() + parse() - faster!
212            let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
213            parts.path_and_query = Some(path_to_send.parse()?);
214            let new_uri = Uri::from_parts(parts)?;
215
216            // Logging with enabled check to avoid string formatting when disabled
217            if tracing::enabled!(tracing::Level::INFO) {
218                // Removed info logging from hot path for performance
219                // Use DEBUG level if needed for troubleshooting
220            }
221
222            *req.uri_mut() = new_uri.clone();
223
224            // Save original host for X-Forwarded headers
225            // Clone HeaderValue directly - 0 allocations!
226            let original_host_header = req.headers().get(hyper::header::HOST).cloned();
227
228            // Update Host header for backend
229            req.headers_mut().remove(hyper::header::HOST);
230            if let Some(authority) = new_uri.authority() {
231                if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
232                    req.headers_mut().insert(hyper::header::HOST, host_value);
233                }
234            }
235
236            // Add X-Forwarded-* headers for backend visibility
237            // X-Forwarded-Host: original Host header from client
238            if let Some(host_value) = original_host_header.clone() {
239                req.headers_mut().insert("X-Forwarded-Host", host_value);
240            }
241
242            // X-Forwarded-Proto: scheme from original request (http or https)
243            let original_scheme = req.uri().scheme_str().unwrap_or("http");
244            // Use from_static for known values - 0 allocations!
245            match original_scheme {
246                "http" => {
247                    req.headers_mut().insert(
248                        "X-Forwarded-Proto",
249                        hyper::header::HeaderValue::from_static("http"),
250                    );
251                }
252                "https" => {
253                    req.headers_mut().insert(
254                        "X-Forwarded-Proto",
255                        hyper::header::HeaderValue::from_static("https"),
256                    );
257                }
258                _ => {} // ignore unknown schemes
259            }
260
261            // X-Forwarded-For: real client IP from TCP connection
262            if let Ok(ip_value) =
263                hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
264            {
265                req.headers_mut().insert("X-Forwarded-For", ip_value);
266            }
267
268            // Remove hop-by-hop headers from request before sending to backend
269            // Connection header must not be proxied (hyper manages connections)
270            req.headers_mut().remove(header::CONNECTION);
271
272            // Remove Accept-Encoding to prevent compression
273            // Compression breaks streaming and SSE
274            req.headers_mut().remove("accept-encoding");
275
276            // Forward request to backend with configurable timeout (default 30s)
277            let backend_timeout = read_timeout.unwrap_or(30);
278            match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
279                Ok(Ok(response)) => {
280                    // Successfully received response from backend
281                    let status = response.status();
282                    let headers = response.headers().clone();
283
284                    // Logging with enabled check to avoid string formatting when disabled
285                    if tracing::enabled!(tracing::Level::INFO) {
286                        // Removed info logging from hot path for performance
287                        // Use DEBUG level if needed for troubleshooting
288                    }
289
290                    // Stream response body directly (no buffering)
291                    let mut builder = Response::builder().status(status);
292
293                    // Copy all headers from backend, filtering out hop-by-hop headers
294                    // Hop-by-hop headers should not be proxied per RFC 7230
295                    // Also remove Content-Length to let hyper handle chunked encoding
296                    for (name, value) in headers.iter() {
297                        if !is_hop_header(name) && name != header::CONTENT_LENGTH {
298                            builder = builder.header(name, value);
299                        }
300                    }
301
302                    // Extract streaming body and convert to BoxBody
303                    let (_, incoming_body) = response.into_parts();
304                    let boxed: ResponseBody = incoming_body
305                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
306                        .boxed();
307
308                    Ok(builder.body(boxed)?)
309                }
310                Ok(Err(e)) => {
311                    // Backend unavailable - return 502 Bad Gateway
312                    error!("Backend connection failed: {:?}", e);
313
314                    if e.is_connect() {
315                        error!("   Reason: Connection refused - backend unavailable");
316                    } else {
317                        error!("   Reason: Other connection error");
318                    }
319
320                    Ok(error_response(
321                        StatusCode::BAD_GATEWAY,
322                        "Backend service unavailable",
323                    ))
324                }
325                Err(_) => {
326                    // Timeout - return 504 Gateway Timeout
327                    error!(
328                        "Backend request timed out after {} seconds",
329                        backend_timeout
330                    );
331
332                    Ok(error_response(
333                        StatusCode::GATEWAY_TIMEOUT,
334                        "Backend request timed out",
335                    ))
336                }
337            }
338        }
339    }
340}
341
342/// Creates HTTP response with error
343fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
344    let body = format!(
345        r#"<!DOCTYPE html>
346        <html>
347        <head><title>{} {}</title></head>
348        <body>
349        <h1>{} {}</h1>
350        <p>{}</p>
351        <hr>
352        <p><em>Rust Proxy Server</em></p>
353        </body>
354        </html>"#,
355        status.as_u16(),
356        status.canonical_reason().unwrap_or("Error"),
357        status.as_u16(),
358        status.canonical_reason().unwrap_or("Error"),
359        message
360    );
361
362    let full = Full::new(Bytes::from(body));
363    let boxed: ResponseBody = full
364        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
365        .boxed();
366
367    Response::builder()
368        .status(status)
369        .header("Content-Type", "text/html; charset=utf-8")
370        .body(boxed)
371        .unwrap()
372}
373
/// Match `path` against a route `pattern`.
///
/// - A pattern of the form `"<prefix>/*"` matches `<prefix>` itself and any path under
///   `<prefix>/`. It returns what follows `<prefix>`, or `"/"` when nothing remains.
///   Matching respects segment boundaries: `"/api/*"` matches `"/api"` and
///   `"/api/users"` (returning `"/users"`) but not `"/apiary"`.
/// - `"/*"` matches every path and returns it unchanged.
/// - Any other pattern must equal `path` exactly, and the match yields `"/"`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let rest = path.strip_prefix(prefix)?;
            if rest.is_empty() {
                // The path is exactly the prefix: forward the root.
                Some("/".to_string())
            } else if prefix.is_empty() || rest.starts_with('/') {
                Some(rest.to_string())
            } else {
                // e.g. "/apiary" against "/api/*": same characters, different segment.
                None
            }
        }
        // Exact match, send root.
        None => (pattern == path).then(|| "/".to_string()),
    }
}