sentinel_proxy/proxy/
http_trait.rs

1//! ProxyHttp trait implementation for SentinelProxy.
2//!
3//! This module contains the Pingora ProxyHttp trait implementation which defines
4//! the core request/response lifecycle handling.
5
6use async_trait::async_trait;
7use pingora::http::ResponseHeader;
8use pingora::prelude::*;
9use pingora::proxy::{ProxyHttp, Session};
10use std::collections::HashMap;
11use std::time::Duration;
12use tracing::{debug, info, warn};
13
14use crate::logging::AccessLogEntry;
15use crate::routing::RequestInfo;
16
17use super::context::RequestContext;
18use super::SentinelProxy;
19
20#[async_trait]
21impl ProxyHttp for SentinelProxy {
22    type CTX = RequestContext;
23
24    fn new_ctx(&self) -> Self::CTX {
25        RequestContext::new()
26    }
27
28    fn fail_to_connect(
29        &self,
30        _session: &mut Session,
31        _peer: &HttpPeer,
32        _ctx: &mut Self::CTX,
33        e: Box<Error>,
34    ) -> Box<Error> {
35        // Log and return the error
36        // Custom error pages are handled in response_filter
37        e
38    }
39
40    async fn upstream_peer(
41        &self,
42        session: &mut Session,
43        ctx: &mut Self::CTX,
44    ) -> Result<Box<HttpPeer>, Box<Error>> {
45        // Track active request
46        self.reload_coordinator.inc_requests();
47
48        // Initialize trace ID
49        ctx.trace_id = self.get_trace_id(session);
50
51        // Cache client address for logging
52        ctx.client_ip = session
53            .client_addr()
54            .map(|a| a.to_string())
55            .unwrap_or_else(|| "unknown".to_string());
56
57        let req_header = session.req_header();
58
59        // Cache request info for access logging
60        ctx.method = req_header.method.to_string();
61        ctx.path = req_header.uri.path().to_string();
62        ctx.query = req_header.uri.query().map(|q| q.to_string());
63        ctx.host = req_header
64            .headers
65            .get("host")
66            .and_then(|v| v.to_str().ok())
67            .map(|s| s.to_string());
68        ctx.user_agent = req_header
69            .headers
70            .get("user-agent")
71            .and_then(|v| v.to_str().ok())
72            .map(|s| s.to_string());
73        ctx.referer = req_header
74            .headers
75            .get("referer")
76            .and_then(|v| v.to_str().ok())
77            .map(|s| s.to_string());
78
79        // Build request info for routing
80        let mut headers = HashMap::new();
81        for (name, value) in req_header.headers.iter() {
82            if let Ok(value_str) = value.to_str() {
83                headers.insert(name.as_str().to_lowercase(), value_str.to_string());
84            }
85        }
86
87        let request_info = RequestInfo {
88            method: ctx.method.clone(),
89            path: ctx.path.clone(),
90            host: ctx.host.clone().unwrap_or_default(),
91            headers,
92            query_params: RequestInfo::parse_query_params(&ctx.path),
93        };
94
95        // Match route
96        let route_match = self
97            .route_matcher
98            .read()
99            .await
100            .match_request(&request_info)
101            .ok_or_else(|| Error::explain(ErrorType::InternalError, "No matching route found"))?;
102
103        ctx.route_id = Some(route_match.route_id.to_string());
104
105        // Check if this is a static file route
106        if route_match.config.service_type == sentinel_config::ServiceType::Static {
107            // Static routes don't need an upstream
108            if self
109                .static_servers
110                .get(route_match.route_id.as_str())
111                .await
112                .is_some()
113            {
114                // Mark this as a static route for later processing
115                ctx.upstream = Some(format!("_static_{}", route_match.route_id));
116                debug!(
117                    correlation_id = %ctx.trace_id,
118                    route_id = %route_match.route_id,
119                    "Route is configured for static file serving"
120                );
121                // Return error to avoid upstream connection for static routes
122                return Err(Error::explain(
123                    ErrorType::InternalError,
124                    "Static file serving handled in request_filter",
125                ));
126            }
127        }
128
129        // Regular route with upstream
130        if let Some(ref upstream) = route_match.config.upstream {
131            ctx.upstream = Some(upstream.clone());
132        } else {
133            return Err(Error::explain(
134                ErrorType::InternalError,
135                format!(
136                    "Route '{}' has no upstream configured",
137                    route_match.route_id
138                ),
139            ));
140        }
141
142        info!(
143            correlation_id = %ctx.trace_id,
144            route_id = %route_match.route_id,
145            upstream = ?ctx.upstream,
146            method = %req_header.method,
147            path = %req_header.uri.path(),
148            "Request matched to route"
149        );
150
151        // Get upstream pool (skip for static routes)
152        if ctx
153            .upstream
154            .as_ref()
155            .is_some_and(|u| u.starts_with("_static_"))
156        {
157            // Static routes are handled in request_filter, should not reach here
158            return Err(Error::explain(
159                ErrorType::InternalError,
160                "Static route should be handled in request_filter",
161            ));
162        }
163
164        let upstream_name = ctx
165            .upstream
166            .as_ref()
167            .ok_or_else(|| Error::explain(ErrorType::InternalError, "No upstream configured"))?;
168        let pool = self.upstream_pools.get(upstream_name).await.ok_or_else(|| {
169            Error::explain(
170                ErrorType::InternalError,
171                format!("Upstream pool '{}' not found", upstream_name),
172            )
173        })?;
174
175        // Select peer from pool with retries
176        let max_retries = route_match
177            .config
178            .retry_policy
179            .as_ref()
180            .map(|r| r.max_attempts)
181            .unwrap_or(1);
182
183        let mut last_error = None;
184        for attempt in 1..=max_retries {
185            ctx.upstream_attempts = attempt;
186
187            match pool.select_peer(None).await {
188                Ok(peer) => {
189                    debug!(
190                        correlation_id = %ctx.trace_id,
191                        attempt = attempt,
192                        "Selected upstream peer"
193                    );
194                    return Ok(Box::new(peer));
195                }
196                Err(e) => {
197                    warn!(
198                        correlation_id = %ctx.trace_id,
199                        attempt = attempt,
200                        error = %e,
201                        "Failed to select upstream peer"
202                    );
203                    last_error = Some(e);
204
205                    if attempt < max_retries {
206                        // Exponential backoff
207                        let backoff = Duration::from_millis(100 * 2_u64.pow(attempt - 1));
208                        tokio::time::sleep(backoff).await;
209                    }
210                }
211            }
212        }
213
214        Err(Error::explain(
215            ErrorType::InternalError,
216            format!("All upstream attempts failed: {:?}", last_error),
217        ))
218    }
219
220    async fn request_filter(
221        &self,
222        session: &mut Session,
223        ctx: &mut Self::CTX,
224    ) -> Result<bool, Box<Error>> {
225        // First, determine the route for this request (needed before upstream_peer)
226        let req_header = session.req_header();
227        let route_info = {
228            let mut headers = HashMap::new();
229            for (name, value) in req_header.headers.iter() {
230                if let Ok(value_str) = value.to_str() {
231                    headers.insert(name.as_str().to_lowercase(), value_str.to_string());
232                }
233            }
234            let host = headers.get("host").cloned().unwrap_or_default();
235            let request_info = RequestInfo {
236                path: req_header.uri.path().to_string(),
237                method: req_header.method.as_str().to_string(),
238                host,
239                headers,
240                query_params: HashMap::new(),
241            };
242            self.route_matcher.read().await.match_request(&request_info)
243        };
244
245        // Handle static file routes
246        if let Some(route_match) = &route_info {
247            if route_match.config.service_type == sentinel_config::ServiceType::Static {
248                return self.handle_static_route(session, ctx, route_match).await;
249            } else if route_match.config.service_type == sentinel_config::ServiceType::Builtin {
250                return self.handle_builtin_route(session, ctx, route_match).await;
251            }
252        }
253
254        // API validation for API routes
255        if let Some(route_id) = ctx.route_id.clone() {
256            if let Some(validator) = self.validators.get(&route_id).await {
257                if let Some(result) = self
258                    .validate_api_request(session, ctx, &route_id, &validator)
259                    .await?
260                {
261                    return Ok(result);
262                }
263            }
264        }
265
266        // Get client address before mutable borrow
267        let client_addr = session
268            .client_addr()
269            .map(|a| format!("{}", a))
270            .unwrap_or_else(|| "unknown".to_string());
271        let client_port = session.client_addr().map(|_| 0).unwrap_or(0);
272
273        let req_header = session.req_header_mut();
274
275        // Add correlation ID header
276        req_header
277            .insert_header("X-Correlation-Id", &ctx.trace_id)
278            .ok();
279        req_header.insert_header("X-Forwarded-By", "Sentinel").ok();
280
281        // Get current config for limits
282        let config = self.config_manager.current();
283
284        // Enforce header limits
285        if req_header.headers.len() > config.limits.max_header_count {
286            warn!(
287                correlation_id = %ctx.trace_id,
288                header_count = req_header.headers.len(),
289                limit = config.limits.max_header_count,
290                "Request exceeds header count limit"
291            );
292
293            self.metrics.record_blocked_request("header_count_exceeded");
294            return Err(Error::explain(ErrorType::InternalError, "Too many headers"));
295        }
296
297        // Check header size
298        let total_header_size: usize = req_header
299            .headers
300            .iter()
301            .map(|(k, v)| k.as_str().len() + v.len())
302            .sum();
303
304        if total_header_size > config.limits.max_header_size_bytes {
305            warn!(
306                correlation_id = %ctx.trace_id,
307                header_size = total_header_size,
308                limit = config.limits.max_header_size_bytes,
309                "Request exceeds header size limit"
310            );
311
312            self.metrics.record_blocked_request("header_size_exceeded");
313            return Err(Error::explain(
314                ErrorType::InternalError,
315                "Headers too large",
316            ));
317        }
318
319        // Process through external agents
320        self.process_agents(session, ctx, &client_addr, client_port)
321            .await?;
322
323        Ok(false) // Continue processing
324    }
325
326    async fn response_filter(
327        &self,
328        _session: &mut Session,
329        upstream_response: &mut ResponseHeader,
330        ctx: &mut Self::CTX,
331    ) -> Result<(), Box<Error>> {
332        // Apply security headers
333        self.apply_security_headers(upstream_response).ok();
334
335        // Add correlation ID to response
336        upstream_response.insert_header("X-Correlation-Id", &ctx.trace_id)?;
337
338        // Record metrics
339        let status = upstream_response.status.as_u16();
340        let duration = ctx.elapsed();
341
342        // Generate custom error pages for error responses
343        if status >= 400 {
344            self.handle_error_response(upstream_response, ctx).await?;
345        }
346
347        self.metrics.record_request(
348            ctx.route_id.as_deref().unwrap_or("unknown"),
349            "GET", // TODO: Get actual method from context
350            status,
351            duration,
352        );
353
354        // Record passive health check
355        if let Some(ref upstream) = ctx.upstream {
356            let success = status < 500;
357            self.passive_health.record_outcome(upstream, success).await;
358
359            // Report to upstream pool
360            if let Some(pool) = self.upstream_pools.get(upstream).await {
361                pool.report_result(upstream, success).await;
362            }
363        }
364
365        info!(
366            correlation_id = %ctx.trace_id,
367            route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
368            upstream = ctx.upstream.as_deref().unwrap_or("unknown"),
369            status = status,
370            duration_ms = duration.as_millis(),
371            attempts = ctx.upstream_attempts,
372            "Request completed"
373        );
374
375        Ok(())
376    }
377
378    async fn logging(&self, session: &mut Session, _error: Option<&Error>, ctx: &mut Self::CTX) {
379        // Decrement active requests
380        self.reload_coordinator.dec_requests();
381
382        let duration = ctx.elapsed();
383
384        // Get response status
385        let status = session
386            .response_written()
387            .map(|r| r.status.as_u16())
388            .unwrap_or(0);
389
390        // Write to access log file if configured
391        if self.log_manager.access_log_enabled() {
392            let access_entry = AccessLogEntry {
393                timestamp: chrono::Utc::now().to_rfc3339(),
394                trace_id: ctx.trace_id.clone(),
395                method: ctx.method.clone(),
396                path: ctx.path.clone(),
397                query: ctx.query.clone(),
398                protocol: "HTTP/1.1".to_string(),
399                status,
400                body_bytes: ctx.response_bytes,
401                duration_ms: duration.as_millis() as u64,
402                client_ip: ctx.client_ip.clone(),
403                user_agent: ctx.user_agent.clone(),
404                referer: ctx.referer.clone(),
405                host: ctx.host.clone(),
406                route_id: ctx.route_id.clone(),
407                upstream: ctx.upstream.clone(),
408                upstream_attempts: ctx.upstream_attempts,
409                instance_id: self.app_state.instance_id.clone(),
410            };
411            self.log_manager.log_access(&access_entry);
412        }
413
414        // Also log to stdout for tracing
415        let log_entry = serde_json::json!({
416            "timestamp": chrono::Utc::now().to_rfc3339(),
417            "trace_id": ctx.trace_id,
418            "instance_id": self.app_state.instance_id,
419            "method": ctx.method,
420            "path": ctx.path,
421            "route_id": ctx.route_id,
422            "upstream": ctx.upstream,
423            "status": status,
424            "duration_ms": duration.as_millis(),
425            "upstream_attempts": ctx.upstream_attempts,
426            "error": _error.map(|e| e.to_string()),
427        });
428
429        debug!("{}", log_entry);
430    }
431}