use async_trait::async_trait;
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::prelude::*;
use pingora::proxy::{ProxyHttp, Session};
use std::collections::HashMap;
use std::time::Duration;
use tracing::{debug, info, warn};

use crate::logging::AccessLogEntry;
use crate::routing::RequestInfo;

use super::context::RequestContext;
use super::SentinelProxy;
19
20#[async_trait]
21impl ProxyHttp for SentinelProxy {
22 type CTX = RequestContext;
23
24 fn new_ctx(&self) -> Self::CTX {
25 RequestContext::new()
26 }
27
28 fn fail_to_connect(
29 &self,
30 _session: &mut Session,
31 _peer: &HttpPeer,
32 _ctx: &mut Self::CTX,
33 e: Box<Error>,
34 ) -> Box<Error> {
35 e
38 }
39
40 async fn upstream_peer(
41 &self,
42 session: &mut Session,
43 ctx: &mut Self::CTX,
44 ) -> Result<Box<HttpPeer>, Box<Error>> {
45 self.reload_coordinator.inc_requests();
47
48 ctx.trace_id = self.get_trace_id(session);
50
51 ctx.client_ip = session
53 .client_addr()
54 .map(|a| a.to_string())
55 .unwrap_or_else(|| "unknown".to_string());
56
57 let req_header = session.req_header();
58
59 ctx.method = req_header.method.to_string();
61 ctx.path = req_header.uri.path().to_string();
62 ctx.query = req_header.uri.query().map(|q| q.to_string());
63 ctx.host = req_header
64 .headers
65 .get("host")
66 .and_then(|v| v.to_str().ok())
67 .map(|s| s.to_string());
68 ctx.user_agent = req_header
69 .headers
70 .get("user-agent")
71 .and_then(|v| v.to_str().ok())
72 .map(|s| s.to_string());
73 ctx.referer = req_header
74 .headers
75 .get("referer")
76 .and_then(|v| v.to_str().ok())
77 .map(|s| s.to_string());
78
79 let mut headers = HashMap::new();
81 for (name, value) in req_header.headers.iter() {
82 if let Ok(value_str) = value.to_str() {
83 headers.insert(name.as_str().to_lowercase(), value_str.to_string());
84 }
85 }
86
87 let request_info = RequestInfo {
88 method: ctx.method.clone(),
89 path: ctx.path.clone(),
90 host: ctx.host.clone().unwrap_or_default(),
91 headers,
92 query_params: RequestInfo::parse_query_params(&ctx.path),
93 };
94
95 let route_match = self
97 .route_matcher
98 .read()
99 .await
100 .match_request(&request_info)
101 .ok_or_else(|| Error::explain(ErrorType::InternalError, "No matching route found"))?;
102
103 ctx.route_id = Some(route_match.route_id.to_string());
104
105 if route_match.config.service_type == sentinel_config::ServiceType::Static {
107 if self
109 .static_servers
110 .get(route_match.route_id.as_str())
111 .await
112 .is_some()
113 {
114 ctx.upstream = Some(format!("_static_{}", route_match.route_id));
116 debug!(
117 correlation_id = %ctx.trace_id,
118 route_id = %route_match.route_id,
119 "Route is configured for static file serving"
120 );
121 return Err(Error::explain(
123 ErrorType::InternalError,
124 "Static file serving handled in request_filter",
125 ));
126 }
127 }
128
129 if let Some(ref upstream) = route_match.config.upstream {
131 ctx.upstream = Some(upstream.clone());
132 } else {
133 return Err(Error::explain(
134 ErrorType::InternalError,
135 format!(
136 "Route '{}' has no upstream configured",
137 route_match.route_id
138 ),
139 ));
140 }
141
142 info!(
143 correlation_id = %ctx.trace_id,
144 route_id = %route_match.route_id,
145 upstream = ?ctx.upstream,
146 method = %req_header.method,
147 path = %req_header.uri.path(),
148 "Request matched to route"
149 );
150
151 if ctx
153 .upstream
154 .as_ref()
155 .is_some_and(|u| u.starts_with("_static_"))
156 {
157 return Err(Error::explain(
159 ErrorType::InternalError,
160 "Static route should be handled in request_filter",
161 ));
162 }
163
164 let upstream_name = ctx
165 .upstream
166 .as_ref()
167 .ok_or_else(|| Error::explain(ErrorType::InternalError, "No upstream configured"))?;
168 let pool = self.upstream_pools.get(upstream_name).await.ok_or_else(|| {
169 Error::explain(
170 ErrorType::InternalError,
171 format!("Upstream pool '{}' not found", upstream_name),
172 )
173 })?;
174
175 let max_retries = route_match
177 .config
178 .retry_policy
179 .as_ref()
180 .map(|r| r.max_attempts)
181 .unwrap_or(1);
182
183 let mut last_error = None;
184 for attempt in 1..=max_retries {
185 ctx.upstream_attempts = attempt;
186
187 match pool.select_peer(None).await {
188 Ok(peer) => {
189 debug!(
190 correlation_id = %ctx.trace_id,
191 attempt = attempt,
192 "Selected upstream peer"
193 );
194 return Ok(Box::new(peer));
195 }
196 Err(e) => {
197 warn!(
198 correlation_id = %ctx.trace_id,
199 attempt = attempt,
200 error = %e,
201 "Failed to select upstream peer"
202 );
203 last_error = Some(e);
204
205 if attempt < max_retries {
206 let backoff = Duration::from_millis(100 * 2_u64.pow(attempt - 1));
208 tokio::time::sleep(backoff).await;
209 }
210 }
211 }
212 }
213
214 Err(Error::explain(
215 ErrorType::InternalError,
216 format!("All upstream attempts failed: {:?}", last_error),
217 ))
218 }
219
220 async fn request_filter(
221 &self,
222 session: &mut Session,
223 ctx: &mut Self::CTX,
224 ) -> Result<bool, Box<Error>> {
225 let req_header = session.req_header();
227 let route_info = {
228 let mut headers = HashMap::new();
229 for (name, value) in req_header.headers.iter() {
230 if let Ok(value_str) = value.to_str() {
231 headers.insert(name.as_str().to_lowercase(), value_str.to_string());
232 }
233 }
234 let host = headers.get("host").cloned().unwrap_or_default();
235 let request_info = RequestInfo {
236 path: req_header.uri.path().to_string(),
237 method: req_header.method.as_str().to_string(),
238 host,
239 headers,
240 query_params: HashMap::new(),
241 };
242 self.route_matcher.read().await.match_request(&request_info)
243 };
244
245 if let Some(route_match) = &route_info {
247 if route_match.config.service_type == sentinel_config::ServiceType::Static {
248 return self.handle_static_route(session, ctx, route_match).await;
249 } else if route_match.config.service_type == sentinel_config::ServiceType::Builtin {
250 return self.handle_builtin_route(session, ctx, route_match).await;
251 }
252 }
253
254 if let Some(route_id) = ctx.route_id.clone() {
256 if let Some(validator) = self.validators.get(&route_id).await {
257 if let Some(result) = self
258 .validate_api_request(session, ctx, &route_id, &validator)
259 .await?
260 {
261 return Ok(result);
262 }
263 }
264 }
265
266 let client_addr = session
268 .client_addr()
269 .map(|a| format!("{}", a))
270 .unwrap_or_else(|| "unknown".to_string());
271 let client_port = session.client_addr().map(|_| 0).unwrap_or(0);
272
273 let req_header = session.req_header_mut();
274
275 req_header
277 .insert_header("X-Correlation-Id", &ctx.trace_id)
278 .ok();
279 req_header.insert_header("X-Forwarded-By", "Sentinel").ok();
280
281 let config = self.config_manager.current();
283
284 if req_header.headers.len() > config.limits.max_header_count {
286 warn!(
287 correlation_id = %ctx.trace_id,
288 header_count = req_header.headers.len(),
289 limit = config.limits.max_header_count,
290 "Request exceeds header count limit"
291 );
292
293 self.metrics.record_blocked_request("header_count_exceeded");
294 return Err(Error::explain(ErrorType::InternalError, "Too many headers"));
295 }
296
297 let total_header_size: usize = req_header
299 .headers
300 .iter()
301 .map(|(k, v)| k.as_str().len() + v.len())
302 .sum();
303
304 if total_header_size > config.limits.max_header_size_bytes {
305 warn!(
306 correlation_id = %ctx.trace_id,
307 header_size = total_header_size,
308 limit = config.limits.max_header_size_bytes,
309 "Request exceeds header size limit"
310 );
311
312 self.metrics.record_blocked_request("header_size_exceeded");
313 return Err(Error::explain(
314 ErrorType::InternalError,
315 "Headers too large",
316 ));
317 }
318
319 self.process_agents(session, ctx, &client_addr, client_port)
321 .await?;
322
323 Ok(false) }
325
326 async fn response_filter(
327 &self,
328 _session: &mut Session,
329 upstream_response: &mut ResponseHeader,
330 ctx: &mut Self::CTX,
331 ) -> Result<(), Box<Error>> {
332 self.apply_security_headers(upstream_response).ok();
334
335 upstream_response.insert_header("X-Correlation-Id", &ctx.trace_id)?;
337
338 let status = upstream_response.status.as_u16();
340 let duration = ctx.elapsed();
341
342 if status >= 400 {
344 self.handle_error_response(upstream_response, ctx).await?;
345 }
346
347 self.metrics.record_request(
348 ctx.route_id.as_deref().unwrap_or("unknown"),
349 "GET", status,
351 duration,
352 );
353
354 if let Some(ref upstream) = ctx.upstream {
356 let success = status < 500;
357 self.passive_health.record_outcome(upstream, success).await;
358
359 if let Some(pool) = self.upstream_pools.get(upstream).await {
361 pool.report_result(upstream, success).await;
362 }
363 }
364
365 info!(
366 correlation_id = %ctx.trace_id,
367 route_id = ctx.route_id.as_deref().unwrap_or("unknown"),
368 upstream = ctx.upstream.as_deref().unwrap_or("unknown"),
369 status = status,
370 duration_ms = duration.as_millis(),
371 attempts = ctx.upstream_attempts,
372 "Request completed"
373 );
374
375 Ok(())
376 }
377
378 async fn logging(&self, session: &mut Session, _error: Option<&Error>, ctx: &mut Self::CTX) {
379 self.reload_coordinator.dec_requests();
381
382 let duration = ctx.elapsed();
383
384 let status = session
386 .response_written()
387 .map(|r| r.status.as_u16())
388 .unwrap_or(0);
389
390 if self.log_manager.access_log_enabled() {
392 let access_entry = AccessLogEntry {
393 timestamp: chrono::Utc::now().to_rfc3339(),
394 trace_id: ctx.trace_id.clone(),
395 method: ctx.method.clone(),
396 path: ctx.path.clone(),
397 query: ctx.query.clone(),
398 protocol: "HTTP/1.1".to_string(),
399 status,
400 body_bytes: ctx.response_bytes,
401 duration_ms: duration.as_millis() as u64,
402 client_ip: ctx.client_ip.clone(),
403 user_agent: ctx.user_agent.clone(),
404 referer: ctx.referer.clone(),
405 host: ctx.host.clone(),
406 route_id: ctx.route_id.clone(),
407 upstream: ctx.upstream.clone(),
408 upstream_attempts: ctx.upstream_attempts,
409 instance_id: self.app_state.instance_id.clone(),
410 };
411 self.log_manager.log_access(&access_entry);
412 }
413
414 let log_entry = serde_json::json!({
416 "timestamp": chrono::Utc::now().to_rfc3339(),
417 "trace_id": ctx.trace_id,
418 "instance_id": self.app_state.instance_id,
419 "method": ctx.method,
420 "path": ctx.path,
421 "route_id": ctx.route_id,
422 "upstream": ctx.upstream,
423 "status": status,
424 "duration_ms": duration.as_millis(),
425 "upstream_attempts": ctx.upstream_attempts,
426 "error": _error.map(|e| e.to_string()),
427 });
428
429 debug!("{}", log_entry);
430 }
431}