warpdrive_proxy/middleware/mod.rs
1//! Middleware system for WarpDrive
2//!
3//! This module provides a composable middleware architecture adapted to Pingora's
4//! filter-based model. Each middleware operates on requests and responses as they
5//! flow through the proxy.
6//!
7//! # Architecture
8//!
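//! Middlewares are applied in the following order.
//!
//! Request phase (`apply_request_filters`):
//! 1. **Trusted ranges** - Normalize the client IP and set the trusted-source flag (runs first)
//! 2. **Static files** - Serve matching paths directly, short-circuiting the proxy
//! 3. **Concurrency** - Limit the total number of concurrent requests
//! 4. **Rate limit** - Enforce request-rate limits
//! 5. **Circuit breaker** - Check whether the upstream is available
//! 6. **Headers** - Inject X-Forwarded-* headers
//! 7. **Logging** - Record the request start
//!
//! Response phase (`apply_response_filters`):
//! 1. **Circuit breaker** - Record upstream success/failure
//! 2. **Sendfile** - X-Sendfile/X-Accel-Redirect support
//! 3. **Compression** - Response compression
//! 4. **Logging** - Log response metrics (runs last to capture all metadata)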
14//!
15//! # Usage with Pingora
16//!
//! ```ignore
//! use warpdrive::middleware::MiddlewareStack;
//! use pingora::proxy::Session;
//!
//! // In a ProxyHttp implementation:
//! async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<()> {
//!     self.middleware.apply_request_filters(session, ctx).await
//! }
//!
//! async fn response_filter(
//!     &self,
//!     session: &mut Session,
//!     upstream_response: &mut ResponseHeader,
//!     ctx: &mut Self::CTX,
//! ) -> Result<()> {
//!     self.middleware
//!         .apply_response_filters(session, upstream_response, ctx)
//!         .await
//! }
//! ```
29//! ```
30
31pub mod circuit_breaker;
32pub mod compression;
33pub mod concurrency;
34pub mod headers;
35pub mod logging;
36pub mod rate_limit;
37pub mod sendfile;
38pub mod static_files;
39pub mod trusted_ranges;
40
41#[cfg(test)]
42mod tests;
43
44use async_trait::async_trait;
45use bytes::Bytes;
46use pingora::http::ResponseHeader;
47use pingora::prelude::*;
48use std::path::PathBuf;
49use std::sync::Arc;
50
51use crate::config::Config;
52
53pub use circuit_breaker::CircuitBreakerMiddleware;
54pub use compression::CompressionMiddleware;
55pub use concurrency::ConcurrencyMiddleware;
56pub use headers::HeadersMiddleware;
57pub use logging::LoggingMiddleware;
58pub use rate_limit::RateLimitMiddleware;
59pub use sendfile::SendfileMiddleware;
60pub use static_files::StaticFilesMiddleware;
61pub use trusted_ranges::TrustedRangesMiddleware;
62
63/// Context for middleware execution
64///
65/// Stores state needed across middleware chain execution, including
66/// request start time, response metadata, and flags.
67pub struct MiddlewareContext {
68 /// Request start timestamp (for logging)
69 pub request_start: std::time::Instant,
70
71 /// Response status code (captured by logging middleware)
72 pub status_code: u16,
73
74 /// Response body size (captured by logging middleware)
75 pub body_size: usize,
76
77 /// Sendfile state (set by sendfile middleware)
78 pub sendfile: SendfileState,
79
80 /// Compression state (set by compression middleware)
81 pub compression: CompressionState,
82
83 /// Static response (set by static files middleware to short-circuit proxy)
84 pub static_response: Option<StaticResponse>,
85
86 /// Concurrency permit (held for request duration, auto-released on drop)
87 pub concurrency_permit: Option<tokio::sync::OwnedSemaphorePermit>,
88
89 /// Streaming flag (set when SSE or other streaming content is detected)
90 ///
91 /// When true, disables response buffering to allow real-time streaming.
92 /// Automatically set when:
93 /// - Content-Type: text/event-stream (Server-Sent Events)
94 /// - X-Accel-Buffering: no (nginx compatibility)
95 pub streaming: bool,
96
97 /// Trusted source flag (set by trusted_ranges middleware)
98 ///
99 /// When true, indicates request comes from a trusted proxy/CDN IP range.
100 /// Used to bypass rate limiting and concurrency limits.
101 pub trusted_source: bool,
102
103 /// Real client IP (normalized by trusted_ranges middleware)
104 ///
105 /// Contains the actual client IP extracted from proxy headers (if trusted)
106 /// or the socket IP (if not from trusted source). Used for:
107 /// - Rate limiting (to rate limit real clients, not proxies)
108 /// - Logging (to log actual client IPs)
109 /// - X-Forwarded-For headers (to maintain proxy chain)
110 pub real_client_ip: std::net::IpAddr,
111}
112
113impl Default for MiddlewareContext {
114 fn default() -> Self {
115 Self {
116 request_start: std::time::Instant::now(),
117 status_code: 200,
118 body_size: 0,
119 sendfile: SendfileState::default(),
120 compression: CompressionState::default(),
121 static_response: None,
122 concurrency_permit: None,
123 streaming: false,
124 trusted_source: false,
125 real_client_ip: "0.0.0.0".parse().unwrap(),
126 }
127 }
128}
129
130/// Body payload for a static response
131#[derive(Debug)]
132pub enum StaticResponseBody {
133 /// In-memory buffer (small files)
134 InMemory(Bytes),
135 /// Streamed from disk (large files)
136 Stream(PathBuf),
137}
138
139/// Pre-computed static response returned by the static files middleware
140#[derive(Debug)]
141pub struct StaticResponse {
142 pub header: ResponseHeader,
143 pub body: StaticResponseBody,
144}
145
146/// Tracks sendfile middleware state for a response
147#[derive(Debug, Default)]
148pub struct SendfileState {
149 /// Whether sendfile is engaged for this response
150 pub active: bool,
151 /// Optional path for logging/diagnostics
152 pub path: Option<String>,
153 /// In-memory body to serve (loaded once when active)
154 pub body: Option<Bytes>,
155 /// Flag indicating if body has been emitted downstream
156 pub served: bool,
157}
158
159impl SendfileState {
160 pub fn activate(&mut self, path: String, body: Bytes) {
161 self.active = true;
162 self.path = Some(path);
163 self.body = Some(body);
164 self.served = false;
165 }
166
167 pub fn reset(&mut self) {
168 self.active = false;
169 self.path = None;
170 self.body = None;
171 self.served = false;
172 }
173}
174
/// Compression algorithm in use for the current response.
///
/// Fieldless and `Copy`, so it also derives `Eq` and `Hash`. That allows it to
/// be used as a map/set key or compared wherever total equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// Brotli compression (preferred, better ratio).
    Brotli,
    /// Gzip compression (fallback, wider support).
    Gzip,
}
183
184/// Tracks compression middleware state across body chunks
185#[derive(Debug, Default)]
186pub enum CompressionState {
187 /// Compression disabled for this response
188 #[default]
189 Disabled,
190 /// Compression enabled and buffering body chunks
191 Pending {
192 buffer: Vec<u8>,
193 encoding: CompressionEncoding,
194 },
195 /// Compression already applied
196 Complete,
197}
198
199impl CompressionState {
200 pub fn enable(&mut self, encoding: CompressionEncoding) {
201 *self = CompressionState::Pending {
202 buffer: Vec::new(),
203 encoding,
204 };
205 }
206
207 pub fn is_enabled(&self) -> bool {
208 matches!(self, CompressionState::Pending { .. })
209 }
210}
211
212/// Middleware stack that composes all middlewares
213///
214/// This struct holds all middleware components and orchestrates their execution
215/// in the correct order for both request and response phases.
216pub struct MiddlewareStack {
217 pub trusted_ranges: Option<TrustedRangesMiddleware>,
218 pub static_files: Option<StaticFilesMiddleware>,
219 pub concurrency: Option<ConcurrencyMiddleware>,
220 pub rate_limit: Option<RateLimitMiddleware>,
221 pub circuit_breaker: Option<CircuitBreakerMiddleware>,
222 pub headers: HeadersMiddleware,
223 pub logging: Option<LoggingMiddleware>,
224 pub sendfile: Option<SendfileMiddleware>,
225 pub compression: Option<CompressionMiddleware>,
226}
227
228impl MiddlewareStack {
229 /// Create a new middleware stack from configuration
230 pub fn new(config: Arc<Config>) -> Self {
231 Self {
232 trusted_ranges: if config.trusted_ranges_file.is_some()
233 || config.client_ip_header.is_some()
234 {
235 TrustedRangesMiddleware::new(
236 config.trusted_ranges_file.clone(),
237 config.client_ip_header.clone(),
238 )
239 .ok()
240 } else {
241 None
242 },
243 static_files: if config.static_enabled {
244 Some(StaticFilesMiddleware::from_config(&config))
245 } else {
246 None
247 },
248 concurrency: if config.max_concurrent_requests > 0 {
249 Some(ConcurrencyMiddleware::new(
250 true,
251 config.max_concurrent_requests,
252 ))
253 } else {
254 None
255 },
256 rate_limit: if config.rate_limit_enabled {
257 Some(RateLimitMiddleware::new(
258 true,
259 config.rate_limit_requests_per_sec,
260 config.rate_limit_burst_size,
261 ))
262 } else {
263 None
264 },
265 circuit_breaker: if config.circuit_breaker_enabled {
266 Some(CircuitBreakerMiddleware::new(
267 true,
268 config.circuit_breaker_failure_threshold,
269 config.circuit_breaker_timeout_secs,
270 ))
271 } else {
272 None
273 },
274 headers: HeadersMiddleware::new(config.clone()),
275 logging: if config.log_requests {
276 Some(LoggingMiddleware::new())
277 } else {
278 None
279 },
280 sendfile: if config.x_sendfile_enabled {
281 Some(SendfileMiddleware::new())
282 } else {
283 None
284 },
285 compression: if config.gzip_compression_enabled {
286 Some(CompressionMiddleware::new())
287 } else {
288 None
289 },
290 }
291 }
292
293 /// Apply all request filters in order
294 ///
295 /// This is called from ProxyHttp::request_filter() to process the incoming request
296 /// before it's sent to the upstream server.
297 pub async fn apply_request_filters(
298 &self,
299 session: &mut Session,
300 ctx: &mut MiddlewareContext,
301 ) -> Result<()> {
302 // 0. Trusted ranges - normalize client IP and set trusted flag (must run FIRST)
303 if let Some(ref trusted_ranges) = self.trusted_ranges {
304 trusted_ranges.request_filter(session, ctx).await?;
305 }
306
307 // 1. Static files - serve directly if path matches (skip all other middleware + proxy)
308 if let Some(ref static_files) = self.static_files {
309 static_files.request_filter(session, ctx).await?;
310
311 // If static response was set, short-circuit the proxy
312 if ctx.static_response.is_some() {
313 return Ok(());
314 }
315 }
316
317 // 2. Concurrency limiting - limit total concurrent requests
318 if let Some(ref concurrency) = self.concurrency {
319 concurrency.request_filter(session, ctx).await?;
320 }
321
322 // 3. Rate limiting - check limits (fail fast)
323 if let Some(ref rate_limit) = self.rate_limit {
324 rate_limit.request_filter(session, ctx).await?;
325 }
326
327 // 4. Circuit breaker - check if upstream is available
328 if let Some(ref circuit_breaker) = self.circuit_breaker {
329 circuit_breaker.request_filter(session, ctx).await?;
330 }
331
332 // 5. Headers middleware - add X-Forwarded-* headers
333 self.headers.request_filter(session, ctx).await?;
334
335 // 6. Logging middleware - record request start
336 if let Some(ref logging) = self.logging {
337 logging.request_filter(session, ctx).await?;
338 }
339
340 Ok(())
341 }
342
343 /// Apply all response filters in order
344 ///
345 /// This is called from ProxyHttp::response_filter() to process the upstream response
346 /// before it's sent to the client.
347 pub async fn apply_response_filters(
348 &self,
349 session: &mut Session,
350 upstream_response: &mut ResponseHeader,
351 ctx: &mut MiddlewareContext,
352 ) -> Result<()> {
353 // 1. Circuit breaker - record upstream response (success/failure)
354 if let Some(ref circuit_breaker) = self.circuit_breaker {
355 circuit_breaker
356 .response_filter(session, upstream_response, ctx)
357 .await?;
358 }
359
360 // 2. Sendfile middleware - handle X-Sendfile headers
361 if let Some(ref sendfile) = self.sendfile {
362 sendfile
363 .response_filter(session, upstream_response, ctx)
364 .await?;
365 }
366
367 // 2. Compression middleware - compress response if applicable
368 if let Some(ref compression) = self.compression {
369 compression
370 .response_filter(session, upstream_response, ctx)
371 .await?;
372 }
373
374 // 3. Logging middleware - log response (must be last to capture all metadata)
375 if let Some(ref logging) = self.logging {
376 logging
377 .response_filter(session, upstream_response, ctx)
378 .await?;
379 }
380
381 Ok(())
382 }
383}
384
385/// Trait for middleware components
386///
387/// Each middleware implements this trait to define its behavior during
388/// request and response phases. Both methods are optional - middlewares
389/// can implement only the phase they care about.
390#[async_trait]
391pub trait Middleware: Send + Sync {
392 /// Process request before sending to upstream
393 async fn request_filter(
394 &self,
395 _session: &mut Session,
396 _ctx: &mut MiddlewareContext,
397 ) -> Result<()> {
398 Ok(())
399 }
400
401 /// Process response before sending to client
402 async fn response_filter(
403 &self,
404 _session: &mut Session,
405 _upstream_response: &mut ResponseHeader,
406 _ctx: &mut MiddlewareContext,
407 ) -> Result<()> {
408 Ok(())
409 }
410}