Skip to main content

tiny_proxy/proxy/
handler.rs

1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14use crate::config::Config;
15use crate::proxy::ActionResult;
16
17use crate::proxy::directives::{
18    handle_header, handle_method, handle_respond, handle_reverse_proxy, handle_uri_replace,
19};
20
21/// Unified response body type - can handle both streaming (Incoming) and buffered (Full<Bytes>)
22/// This allows us to support SSE streaming while maintaining a simple API
23type ResponseBody =
24    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
25
26/// Check if header is hop-by-hop (should not be proxied)
27///
28/// Hop-by-hop headers are defined in RFC 7230 Section 6.1
29/// These headers are meant for a single connection and should NOT be proxied
30/// Uses hyper::header constants for optimal performance (no allocations!)
31fn is_hop_header(name: &header::HeaderName) -> bool {
32    matches!(
33        name,
34        &header::CONNECTION
35            | &header::UPGRADE
36            | &header::TE
37            | &header::TRAILER
38            | &header::PROXY_AUTHENTICATE
39            | &header::PROXY_AUTHORIZATION
40    )
41}
42
43/// Process directives in order, applying modifications and returning final action
44/// Supports recursive handling of handle_path blocks
45pub fn process_directives(
46    directives: &[crate::config::Directive],
47    req: &mut Request<Incoming>,
48    current_path: &str,
49) -> Result<ActionResult, String> {
50    let mut modified_path = current_path.to_string();
51
52    for directive in directives {
53        match directive {
54            // Apply header modifications using directive handler
55
56            // Apply header modifications using directive handler
57            crate::config::Directive::Header { name, value } => {
58                if let Err(e) = handle_header(name, value, req) {
59                    info!("   Failed to apply header {}: {}", name, e);
60                }
61            }
62
63            // Apply URI replacements using directive handler
64            crate::config::Directive::UriReplace { find, replace } => {
65                handle_uri_replace(find, replace, &mut modified_path);
66            }
67
68            // Handle path-based routing recursively
69            crate::config::Directive::HandlePath {
70                pattern,
71                directives: nested_directives,
72            } => {
73                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
74                    info!("   Matched handle_path: {}", pattern);
75                    // Recursively process nested directives with remaining path
76                    return process_directives(nested_directives, req, &remaining_path);
77                }
78            }
79
80            // Method-based directives
81            crate::config::Directive::Method {
82                methods,
83                directives: nested_directives,
84            } => {
85                if handle_method(methods, req) {
86                    info!("   Matched method directive");
87                    // Process nested directives with same path
88                    return process_directives(nested_directives, req, &modified_path);
89                }
90            }
91
92            // Direct response - return immediately using directive handler
93            crate::config::Directive::Respond { status, body } => {
94                return Ok(handle_respond(status, body));
95            }
96
97            // Reverse proxy - return action using directive handler
98            crate::config::Directive::ReverseProxy { to } => {
99                return Ok(handle_reverse_proxy(to, &modified_path));
100            }
101        }
102    }
103
104    Err(format!(
105        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
106        current_path
107    ))
108}
109
110/// Process a single request through the proxy
111///
112/// This implementation ALWAYS streams backend responses (nginx-style):
113/// - No buffering of response body
114/// - Direct streaming from backend to client
115/// - Works for both SSE and regular HTTP
116/// - Optimal performance and memory usage
117///
118/// For direct responses (Respond directive) and errors, buffering is used
119/// since these are small and generated by the proxy itself
120pub async fn proxy(
121    mut req: Request<Incoming>,
122    client: Client<HttpsConnector<HttpConnector>, Incoming>,
123    config: Arc<Config>,
124) -> Result<Response<ResponseBody>, Error> {
125    // Get path from URI (using String to avoid borrow conflict with mutable req)
126    let path = req.uri().path().to_string();
127
128    // Get host from Host header (includes port, e.g., "localhost:8080")
129    let host = req
130        .headers()
131        .get(hyper::header::HOST)
132        .and_then(|h| h.to_str().ok())
133        .unwrap_or("localhost");
134
135    // Logging with enabled check to avoid string formatting when disabled
136    if tracing::enabled!(tracing::Level::INFO) {
137        // Removed info logging from hot path for performance
138        // Use DEBUG level if needed for troubleshooting
139    }
140
141    // Find site configuration by host (with port!)
142    let site_config = match config.sites.get(host) {
143        Some(config) => config,
144        None => {
145            error!("No configuration found for host: {}", host);
146            return Ok(error_response(
147                StatusCode::NOT_FOUND,
148                &format!("No configuration found for host: {}", host),
149            ));
150        }
151    };
152
153    // Process directives in correct order
154    let action_result =
155        process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;
156
157    // Execute action
158    match action_result {
159        ActionResult::Respond { status, body } => {
160            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
161            let boxed: ResponseBody = Full::new(Bytes::from(body))
162                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
163                .boxed();
164            Ok(Response::builder().status(status_code).body(boxed).unwrap())
165        }
166        ActionResult::ReverseProxy {
167            backend_url,
168            path_to_send,
169        } => {
170            // Add protocol if missing
171            let backend_with_proto =
172                if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
173                    backend_url
174                } else {
175                    format!("http://{}", backend_url)
176                };
177
178            // Use Uri::from_parts() instead of format!() + parse() - faster!
179            let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
180            parts.path_and_query = Some(path_to_send.parse()?);
181            let new_uri = Uri::from_parts(parts)?;
182
183            // Logging with enabled check to avoid string formatting when disabled
184            if tracing::enabled!(tracing::Level::INFO) {
185                // Removed info logging from hot path for performance
186                // Use DEBUG level if needed for troubleshooting
187            }
188
189            *req.uri_mut() = new_uri.clone();
190
191            // Save original host for X-Forwarded headers
192            // Clone HeaderValue directly - 0 allocations!
193            let original_host_header = req.headers().get(hyper::header::HOST).cloned();
194
195            // Update Host header for backend
196            req.headers_mut().remove(hyper::header::HOST);
197            if let Some(authority) = new_uri.authority() {
198                if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
199                    req.headers_mut().insert(hyper::header::HOST, host_value);
200                }
201            }
202
203            // Add X-Forwarded-* headers for backend visibility
204            // X-Forwarded-Host: original Host header from client
205            if let Some(host_value) = original_host_header.clone() {
206                req.headers_mut().insert("X-Forwarded-Host", host_value);
207            }
208
209            // X-Forwarded-Proto: scheme from original request (http or https)
210            let original_scheme = req.uri().scheme_str().unwrap_or("http");
211            // Use from_static for known values - 0 allocations!
212            match original_scheme {
213                "http" => {
214                    req.headers_mut().insert(
215                        "X-Forwarded-Proto",
216                        hyper::header::HeaderValue::from_static("http"),
217                    );
218                }
219                "https" => {
220                    req.headers_mut().insert(
221                        "X-Forwarded-Proto",
222                        hyper::header::HeaderValue::from_static("https"),
223                    );
224                }
225                _ => {} // ignore unknown schemes
226            }
227
228            // X-Forwarded-For: client IP address
229            // TODO: Extract real client IP from connection info
230            // For now, we can use the original host as a placeholder
231            if let Some(for_value) = original_host_header {
232                req.headers_mut().insert("X-Forwarded-For", for_value);
233            }
234
235            // Remove hop-by-hop headers from request before sending to backend
236            // Connection header must not be proxied (hyper manages connections)
237            req.headers_mut().remove(header::CONNECTION);
238
239            // Remove Accept-Encoding to prevent compression
240            // Compression breaks streaming and SSE
241            req.headers_mut().remove("accept-encoding");
242
243            // Forward request to backend with 30 second timeout
244            match timeout(Duration::from_secs(30), client.request(req)).await {
245                Ok(Ok(response)) => {
246                    // Successfully received response from backend
247                    let status = response.status();
248                    let headers = response.headers().clone();
249
250                    // Logging with enabled check to avoid string formatting when disabled
251                    if tracing::enabled!(tracing::Level::INFO) {
252                        // Removed info logging from hot path for performance
253                        // Use DEBUG level if needed for troubleshooting
254                    }
255
256                    // Stream response body directly (no buffering)
257                    let mut builder = Response::builder().status(status);
258
259                    // Copy all headers from backend, filtering out hop-by-hop headers
260                    // Hop-by-hop headers should not be proxied per RFC 7230
261                    // Also remove Content-Length to let hyper handle chunked encoding
262                    for (name, value) in headers.iter() {
263                        if !is_hop_header(name) && name != header::CONTENT_LENGTH {
264                            builder = builder.header(name, value);
265                        }
266                    }
267
268                    // Extract streaming body and convert to BoxBody
269                    let (_, incoming_body) = response.into_parts();
270                    let boxed: ResponseBody = incoming_body
271                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
272                        .boxed();
273
274                    Ok(builder.body(boxed).unwrap())
275                }
276                Ok(Err(e)) => {
277                    // Backend unavailable - return 502 Bad Gateway
278                    error!("Backend connection failed: {:?}", e);
279
280                    if e.is_connect() {
281                        error!("   Reason: Connection refused - backend unavailable");
282                    } else {
283                        error!("   Reason: Other connection error");
284                    }
285
286                    Ok(error_response(
287                        StatusCode::BAD_GATEWAY,
288                        "Backend service unavailable",
289                    ))
290                }
291                Err(_) => {
292                    // Timeout - return 504 Gateway Timeout
293                    error!("Backend request timed out after 30 seconds");
294
295                    Ok(error_response(
296                        StatusCode::GATEWAY_TIMEOUT,
297                        "Backend request timed out",
298                    ))
299                }
300            }
301        }
302    }
303}
304
305/// Creates HTTP response with error
306fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
307    let body = format!(
308        r#"<!DOCTYPE html>
309        <html>
310        <head><title>{} {}</title></head>
311        <body>
312        <h1>{} {}</h1>
313        <p>{}</p>
314        <hr>
315        <p><em>Rust Proxy Server</em></p>
316        </body>
317        </html>"#,
318        status.as_u16(),
319        status.canonical_reason().unwrap_or("Error"),
320        status.as_u16(),
321        status.canonical_reason().unwrap_or("Error"),
322        message
323    );
324
325    let full = Full::new(Bytes::from(body));
326    let boxed: ResponseBody = full
327        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
328        .boxed();
329
330    Response::builder()
331        .status(status)
332        .header("Content-Type", "text/html; charset=utf-8")
333        .body(boxed)
334        .unwrap()
335}
336
/// Matches `path` against `pattern` and returns the path to use inside the matched block.
///
/// Two pattern forms are supported:
/// - `"<prefix>/*"`: matches `<prefix>` itself and anything below it. The prefix is
///   stripped and the remainder (always starting with `/`) is returned; a bare
///   `<prefix>` yields `"/"`.
/// - anything else: exact comparison; a match yields `"/"`.
///
/// The prefix only matches on a path-segment boundary, so `"/api/*"` matches `/api` and
/// `/api/users` but not `/apifoo`. Without that check the remainder would be `foo`, which
/// is not a valid request path.
///
/// Returns `None` when `path` does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let remaining = path.strip_prefix(prefix)?;
            if remaining.is_empty() {
                // The path is exactly the prefix: treat it like the exact-match case.
                Some("/".to_string())
            } else if remaining.starts_with('/') {
                Some(remaining.to_string())
            } else {
                // The prefix ended in the middle of a segment (e.g. "/api" vs "/apifoo").
                None
            }
        }
        // Exact match, send root.
        None => (pattern == path).then(|| "/".to_string()),
    }
}