sentinel_proxy/proxy/
http_trait.rs

//! ProxyHttp trait implementation for SentinelProxy.
//!
//! This module contains the Pingora ProxyHttp trait implementation which defines
//! the core request/response lifecycle handling.
5
6use async_trait::async_trait;
7use pingora::http::ResponseHeader;
8use pingora::prelude::*;
9use pingora::proxy::{ProxyHttp, Session};
10use pingora::upstreams::peer::Peer;
11use std::time::Duration;
12use tracing::{debug, error, info, trace, warn};
13
14use crate::logging::AccessLogEntry;
15use crate::routing::RequestInfo;
16
17use super::context::RequestContext;
18use super::SentinelProxy;
19
20#[async_trait]
21impl ProxyHttp for SentinelProxy {
22    type CTX = RequestContext;
23
24    fn new_ctx(&self) -> Self::CTX {
25        RequestContext::new()
26    }
27
28    fn fail_to_connect(
29        &self,
30        _session: &mut Session,
31        peer: &HttpPeer,
32        ctx: &mut Self::CTX,
33        e: Box<Error>,
34    ) -> Box<Error> {
35        error!(
36            correlation_id = %ctx.trace_id,
37            route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
38            upstream = ctx.upstream.as_deref().unwrap_or("unknown"),
39            peer_address = %peer.address(),
40            error = %e,
41            "Failed to connect to upstream peer"
42        );
43        // Custom error pages are handled in response_filter
44        e
45    }
46
47    async fn upstream_peer(
48        &self,
49        session: &mut Session,
50        ctx: &mut Self::CTX,
51    ) -> Result<Box<HttpPeer>, Box<Error>> {
52        // Track active request
53        self.reload_coordinator.inc_requests();
54
55        // Initialize trace ID
56        ctx.trace_id = self.get_trace_id(session);
57
58        // Cache client address for logging
59        ctx.client_ip = session
60            .client_addr()
61            .map(|a| a.to_string())
62            .unwrap_or_else(|| "unknown".to_string());
63
64        trace!(
65            correlation_id = %ctx.trace_id,
66            client_ip = %ctx.client_ip,
67            "Request received, initializing context"
68        );
69
70        let req_header = session.req_header();
71
72        // Cache request info for access logging
73        ctx.method = req_header.method.to_string();
74        ctx.path = req_header.uri.path().to_string();
75        ctx.query = req_header.uri.query().map(|q| q.to_string());
76        ctx.host = req_header
77            .headers
78            .get("host")
79            .and_then(|v| v.to_str().ok())
80            .map(|s| s.to_string());
81        ctx.user_agent = req_header
82            .headers
83            .get("user-agent")
84            .and_then(|v| v.to_str().ok())
85            .map(|s| s.to_string());
86        ctx.referer = req_header
87            .headers
88            .get("referer")
89            .and_then(|v| v.to_str().ok())
90            .map(|s| s.to_string());
91
92        // Match route using sync RwLock (scoped to ensure lock is released before async ops)
93        let (route_match, route_duration) = {
94            let route_matcher = self.route_matcher.read();
95            let host = ctx.host.as_deref().unwrap_or("");
96
97            // Build request info (zero-copy for common case)
98            let mut request_info = RequestInfo::new(&ctx.method, &ctx.path, host);
99
100            // Only build headers HashMap if any route needs header matching
101            if route_matcher.needs_headers() {
102                request_info = request_info.with_headers(
103                    RequestInfo::build_headers(req_header.headers.iter())
104                );
105            }
106
107            // Only parse query params if any route needs query param matching
108            if route_matcher.needs_query_params() {
109                request_info = request_info.with_query_params(
110                    RequestInfo::parse_query_params(&ctx.path)
111                );
112            }
113
114            trace!(
115                correlation_id = %ctx.trace_id,
116                method = %request_info.method,
117                path = %request_info.path,
118                host = %request_info.host,
119                "Built request info for route matching"
120            );
121
122            let route_start = std::time::Instant::now();
123            let route_match = route_matcher
124                .match_request(&request_info)
125                .ok_or_else(|| {
126                    warn!(
127                        correlation_id = %ctx.trace_id,
128                        method = %request_info.method,
129                        path = %request_info.path,
130                        host = %request_info.host,
131                        "No matching route found for request"
132                    );
133                    Error::explain(ErrorType::InternalError, "No matching route found")
134                })?;
135            let route_duration = route_start.elapsed();
136            // Lock is dropped here when block ends
137            (route_match, route_duration)
138        };
139
140        ctx.route_id = Some(route_match.route_id.to_string());
141        ctx.route_config = Some(route_match.config.clone());
142
143        trace!(
144            correlation_id = %ctx.trace_id,
145            route_id = %route_match.route_id,
146            route_duration_us = route_duration.as_micros(),
147            service_type = ?route_match.config.service_type,
148            "Route matched"
149        );
150
151        // Check if this is a static file route
152        if route_match.config.service_type == sentinel_config::ServiceType::Static {
153            trace!(
154                correlation_id = %ctx.trace_id,
155                route_id = %route_match.route_id,
156                "Route type is static, checking for static server"
157            );
158            // Static routes don't need an upstream
159            if self
160                .static_servers
161                .get(route_match.route_id.as_str())
162                .await
163                .is_some()
164            {
165                // Mark this as a static route for later processing
166                ctx.upstream = Some(format!("_static_{}", route_match.route_id));
167                info!(
168                    correlation_id = %ctx.trace_id,
169                    route_id = %route_match.route_id,
170                    path = %ctx.path,
171                    "Serving static file"
172                );
173                // Return error to avoid upstream connection for static routes
174                return Err(Error::explain(
175                    ErrorType::InternalError,
176                    "Static file serving handled in request_filter",
177                ));
178            }
179        }
180
181        // Regular route with upstream
182        if let Some(ref upstream) = route_match.config.upstream {
183            ctx.upstream = Some(upstream.clone());
184            trace!(
185                correlation_id = %ctx.trace_id,
186                route_id = %route_match.route_id,
187                upstream = %upstream,
188                "Upstream configured for route"
189            );
190        } else {
191            error!(
192                correlation_id = %ctx.trace_id,
193                route_id = %route_match.route_id,
194                "Route has no upstream configured"
195            );
196            return Err(Error::explain(
197                ErrorType::InternalError,
198                format!(
199                    "Route '{}' has no upstream configured",
200                    route_match.route_id
201                ),
202            ));
203        }
204
205        info!(
206            correlation_id = %ctx.trace_id,
207            route_id = %route_match.route_id,
208            upstream = ?ctx.upstream,
209            method = %req_header.method,
210            path = %req_header.uri.path(),
211            host = ctx.host.as_deref().unwrap_or("-"),
212            client_ip = %ctx.client_ip,
213            "Processing request"
214        );
215
216        // Get upstream pool (skip for static routes)
217        if ctx
218            .upstream
219            .as_ref()
220            .is_some_and(|u| u.starts_with("_static_"))
221        {
222            // Static routes are handled in request_filter, should not reach here
223            return Err(Error::explain(
224                ErrorType::InternalError,
225                "Static route should be handled in request_filter",
226            ));
227        }
228
229        let upstream_name = ctx
230            .upstream
231            .as_ref()
232            .ok_or_else(|| Error::explain(ErrorType::InternalError, "No upstream configured"))?;
233
234        trace!(
235            correlation_id = %ctx.trace_id,
236            upstream = %upstream_name,
237            "Looking up upstream pool"
238        );
239
240        let pool = self.upstream_pools.get(upstream_name).await.ok_or_else(|| {
241            error!(
242                correlation_id = %ctx.trace_id,
243                upstream = %upstream_name,
244                "Upstream pool not found"
245            );
246            Error::explain(
247                ErrorType::InternalError,
248                format!("Upstream pool '{}' not found", upstream_name),
249            )
250        })?;
251
252        // Select peer from pool with retries
253        let max_retries = route_match
254            .config
255            .retry_policy
256            .as_ref()
257            .map(|r| r.max_attempts)
258            .unwrap_or(1);
259
260        trace!(
261            correlation_id = %ctx.trace_id,
262            upstream = %upstream_name,
263            max_retries = max_retries,
264            "Starting upstream peer selection"
265        );
266
267        let mut last_error = None;
268        let selection_start = std::time::Instant::now();
269
270        for attempt in 1..=max_retries {
271            ctx.upstream_attempts = attempt;
272
273            trace!(
274                correlation_id = %ctx.trace_id,
275                upstream = %upstream_name,
276                attempt = attempt,
277                max_retries = max_retries,
278                "Attempting to select upstream peer"
279            );
280
281            match pool.select_peer(None).await {
282                Ok(peer) => {
283                    let selection_duration = selection_start.elapsed();
284                    debug!(
285                        correlation_id = %ctx.trace_id,
286                        upstream = %upstream_name,
287                        peer_address = %peer.address(),
288                        attempt = attempt,
289                        selection_duration_us = selection_duration.as_micros(),
290                        "Selected upstream peer"
291                    );
292                    return Ok(Box::new(peer));
293                }
294                Err(e) => {
295                    warn!(
296                        correlation_id = %ctx.trace_id,
297                        upstream = %upstream_name,
298                        attempt = attempt,
299                        max_retries = max_retries,
300                        error = %e,
301                        "Failed to select upstream peer"
302                    );
303                    last_error = Some(e);
304
305                    if attempt < max_retries {
306                        // Exponential backoff
307                        let backoff = Duration::from_millis(100 * 2_u64.pow(attempt - 1));
308                        trace!(
309                            correlation_id = %ctx.trace_id,
310                            backoff_ms = backoff.as_millis(),
311                            "Backing off before retry"
312                        );
313                        tokio::time::sleep(backoff).await;
314                    }
315                }
316            }
317        }
318
319        let selection_duration = selection_start.elapsed();
320        error!(
321            correlation_id = %ctx.trace_id,
322            upstream = %upstream_name,
323            attempts = max_retries,
324            selection_duration_ms = selection_duration.as_millis(),
325            last_error = ?last_error,
326            "All upstream selection attempts failed"
327        );
328
329        Err(Error::explain(
330            ErrorType::InternalError,
331            format!("All upstream attempts failed: {:?}", last_error),
332        ))
333    }
334
335    async fn request_filter(
336        &self,
337        session: &mut Session,
338        ctx: &mut Self::CTX,
339    ) -> Result<bool, Box<Error>> {
340        trace!(
341            correlation_id = %ctx.trace_id,
342            route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
343            "Starting request filter phase"
344        );
345
346        // Use cached route config from upstream_peer (avoids duplicate route matching)
347        // Handle static file and builtin routes
348        if let Some(route_config) = ctx.route_config.clone() {
349            if route_config.service_type == sentinel_config::ServiceType::Static {
350                trace!(
351                    correlation_id = %ctx.trace_id,
352                    route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
353                    "Handling static file route"
354                );
355                // Create a minimal RouteMatch for the handler
356                let route_match = crate::routing::RouteMatch {
357                    route_id: sentinel_common::RouteId::new(ctx.route_id.as_deref().unwrap_or("")),
358                    config: route_config.clone(),
359                    policies: route_config.policies.clone(),
360                };
361                return self.handle_static_route(session, ctx, &route_match).await;
362            } else if route_config.service_type == sentinel_config::ServiceType::Builtin {
363                trace!(
364                    correlation_id = %ctx.trace_id,
365                    route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
366                    builtin_handler = ?route_config.builtin_handler,
367                    "Handling builtin route"
368                );
369                // Create a minimal RouteMatch for the handler
370                let route_match = crate::routing::RouteMatch {
371                    route_id: sentinel_common::RouteId::new(ctx.route_id.as_deref().unwrap_or("")),
372                    config: route_config.clone(),
373                    policies: route_config.policies.clone(),
374                };
375                return self.handle_builtin_route(session, ctx, &route_match).await;
376            }
377        }
378
379        // API validation for API routes
380        if let Some(route_id) = ctx.route_id.clone() {
381            if let Some(validator) = self.validators.get(&route_id).await {
382                trace!(
383                    correlation_id = %ctx.trace_id,
384                    route_id = %route_id,
385                    "Running API schema validation"
386                );
387                if let Some(result) = self
388                    .validate_api_request(session, ctx, &route_id, &validator)
389                    .await?
390                {
391                    debug!(
392                        correlation_id = %ctx.trace_id,
393                        route_id = %route_id,
394                        validation_passed = result,
395                        "API validation complete"
396                    );
397                    return Ok(result);
398                }
399            }
400        }
401
402        // Get client address before mutable borrow
403        let client_addr = session
404            .client_addr()
405            .map(|a| format!("{}", a))
406            .unwrap_or_else(|| "unknown".to_string());
407        let client_port = session.client_addr().map(|_| 0).unwrap_or(0);
408
409        let req_header = session.req_header_mut();
410
411        // Add correlation ID header
412        req_header
413            .insert_header("X-Correlation-Id", &ctx.trace_id)
414            .ok();
415        req_header.insert_header("X-Forwarded-By", "Sentinel").ok();
416
417        // Get current config for limits
418        let config = self.config_manager.current();
419
420        trace!(
421            correlation_id = %ctx.trace_id,
422            "Checking request limits"
423        );
424
425        // Enforce header limits
426        let header_count = req_header.headers.len();
427        if header_count > config.limits.max_header_count {
428            warn!(
429                correlation_id = %ctx.trace_id,
430                header_count = header_count,
431                limit = config.limits.max_header_count,
432                "Request blocked: exceeds header count limit"
433            );
434
435            self.metrics.record_blocked_request("header_count_exceeded");
436            return Err(Error::explain(ErrorType::InternalError, "Too many headers"));
437        }
438
439        // Check header size
440        let total_header_size: usize = req_header
441            .headers
442            .iter()
443            .map(|(k, v)| k.as_str().len() + v.len())
444            .sum();
445
446        if total_header_size > config.limits.max_header_size_bytes {
447            warn!(
448                correlation_id = %ctx.trace_id,
449                header_size = total_header_size,
450                limit = config.limits.max_header_size_bytes,
451                "Request blocked: exceeds header size limit"
452            );
453
454            self.metrics.record_blocked_request("header_size_exceeded");
455            return Err(Error::explain(
456                ErrorType::InternalError,
457                "Headers too large",
458            ));
459        }
460
461        trace!(
462            correlation_id = %ctx.trace_id,
463            header_count = header_count,
464            header_size = total_header_size,
465            "Request limits check passed"
466        );
467
468        // Process through external agents
469        trace!(
470            correlation_id = %ctx.trace_id,
471            "Processing request through agents"
472        );
473        self.process_agents(session, ctx, &client_addr, client_port)
474            .await?;
475
476        trace!(
477            correlation_id = %ctx.trace_id,
478            "Request filter phase complete, forwarding to upstream"
479        );
480
481        Ok(false) // Continue processing
482    }
483
484    async fn response_filter(
485        &self,
486        _session: &mut Session,
487        upstream_response: &mut ResponseHeader,
488        ctx: &mut Self::CTX,
489    ) -> Result<(), Box<Error>> {
490        let status = upstream_response.status.as_u16();
491        let duration = ctx.elapsed();
492
493        trace!(
494            correlation_id = %ctx.trace_id,
495            status = status,
496            "Starting response filter phase"
497        );
498
499        // Apply security headers
500        trace!(
501            correlation_id = %ctx.trace_id,
502            "Applying security headers"
503        );
504        self.apply_security_headers(upstream_response).ok();
505
506        // Add correlation ID to response
507        upstream_response.insert_header("X-Correlation-Id", &ctx.trace_id)?;
508
509        // Generate custom error pages for error responses
510        if status >= 400 {
511            trace!(
512                correlation_id = %ctx.trace_id,
513                status = status,
514                "Handling error response"
515            );
516            self.handle_error_response(upstream_response, ctx).await?;
517        }
518
519        // Record metrics
520        self.metrics.record_request(
521            ctx.route_id.as_deref().unwrap_or("unknown"),
522            &ctx.method,
523            status,
524            duration,
525        );
526
527        // Record passive health check
528        if let Some(ref upstream) = ctx.upstream {
529            let success = status < 500;
530
531            trace!(
532                correlation_id = %ctx.trace_id,
533                upstream = %upstream,
534                success = success,
535                status = status,
536                "Recording passive health check result"
537            );
538
539            self.passive_health.record_outcome(upstream, success).await;
540
541            // Report to upstream pool
542            if let Some(pool) = self.upstream_pools.get(upstream).await {
543                pool.report_result(upstream, success).await;
544            }
545
546            if !success {
547                warn!(
548                    correlation_id = %ctx.trace_id,
549                    upstream = %upstream,
550                    status = status,
551                    "Upstream returned error status"
552                );
553            }
554        }
555
556        // Final request completion log
557        if status >= 500 {
558            error!(
559                correlation_id = %ctx.trace_id,
560                route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
561                upstream = ctx.upstream.as_deref().unwrap_or("none"),
562                method = %ctx.method,
563                path = %ctx.path,
564                status = status,
565                duration_ms = duration.as_millis(),
566                attempts = ctx.upstream_attempts,
567                "Request completed with server error"
568            );
569        } else if status >= 400 {
570            warn!(
571                correlation_id = %ctx.trace_id,
572                route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
573                upstream = ctx.upstream.as_deref().unwrap_or("none"),
574                method = %ctx.method,
575                path = %ctx.path,
576                status = status,
577                duration_ms = duration.as_millis(),
578                "Request completed with client error"
579            );
580        } else {
581            info!(
582                correlation_id = %ctx.trace_id,
583                route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
584                upstream = ctx.upstream.as_deref().unwrap_or("none"),
585                method = %ctx.method,
586                path = %ctx.path,
587                status = status,
588                duration_ms = duration.as_millis(),
589                attempts = ctx.upstream_attempts,
590                "Request completed"
591            );
592        }
593
594        Ok(())
595    }
596
597    async fn logging(&self, session: &mut Session, _error: Option<&Error>, ctx: &mut Self::CTX) {
598        // Decrement active requests
599        self.reload_coordinator.dec_requests();
600
601        let duration = ctx.elapsed();
602
603        // Get response status
604        let status = session
605            .response_written()
606            .map(|r| r.status.as_u16())
607            .unwrap_or(0);
608
609        // Write to access log file if configured
610        if self.log_manager.access_log_enabled() {
611            let access_entry = AccessLogEntry {
612                timestamp: chrono::Utc::now().to_rfc3339(),
613                trace_id: ctx.trace_id.clone(),
614                method: ctx.method.clone(),
615                path: ctx.path.clone(),
616                query: ctx.query.clone(),
617                protocol: "HTTP/1.1".to_string(),
618                status,
619                body_bytes: ctx.response_bytes,
620                duration_ms: duration.as_millis() as u64,
621                client_ip: ctx.client_ip.clone(),
622                user_agent: ctx.user_agent.clone(),
623                referer: ctx.referer.clone(),
624                host: ctx.host.clone(),
625                route_id: ctx.route_id.clone(),
626                upstream: ctx.upstream.clone(),
627                upstream_attempts: ctx.upstream_attempts,
628                instance_id: self.app_state.instance_id.clone(),
629            };
630            self.log_manager.log_access(&access_entry);
631        }
632
633        // Log to tracing at debug level (avoid allocations if debug disabled)
634        if tracing::enabled!(tracing::Level::DEBUG) {
635            debug!(
636                trace_id = %ctx.trace_id,
637                method = %ctx.method,
638                path = %ctx.path,
639                route_id = ?ctx.route_id,
640                upstream = ?ctx.upstream,
641                status = status,
642                duration_ms = duration.as_millis() as u64,
643                upstream_attempts = ctx.upstream_attempts,
644                error = ?_error.map(|e| e.to_string()),
645                "Request completed"
646            );
647        }
648    }
649}