warpdrive_proxy/middleware/
mod.rs

1//! Middleware system for WarpDrive
2//!
3//! This module provides a composable middleware architecture adapted to Pingora's
4//! filter-based model. Each middleware operates on requests and responses as they
5//! flow through the proxy.
6//!
7//! # Architecture
8//!
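//! Request filters are applied in the following order:
//! 1. **Trusted ranges** - Normalize the client IP and set the trusted flag
//! 2. **Static files** - Serve matching paths directly (short-circuits the remaining steps)
//! 3. **Concurrency** - Limit total concurrent requests
//! 4. **Rate limit** - Check request rate limits
//! 5. **Circuit breaker** - Check whether the upstream is available
//! 6. **Headers** - Inject X-Forwarded-* headers
//! 7. **Logging** - Record request start
//!
//! Response filters are applied in the following order:
//! 1. **Circuit breaker** - Record the upstream response (success/failure)
//! 2. **Sendfile** - X-Sendfile/X-Accel-Redirect support
//! 3. **Compression** - Response compression
//! 4. **Logging** - Log request/response metrics (last, to capture all metadata)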
14//!
15//! # Usage with Pingora
16//!
17//! ```no_run
18//! use warpdrive::middleware::MiddlewareStack;
19//! use pingora::proxy::Session;
20//!
21//! // In ProxyHttp implementation:
22//! async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<()> {
23//!     self.middleware.apply_request_filters(session, ctx).await
24//! }
25//!
//! async fn response_filter(
//!     &self,
//!     session: &mut Session,
//!     upstream_response: &mut ResponseHeader,
//!     ctx: &mut Self::CTX,
//! ) -> Result<()> {
//!     self.middleware.apply_response_filters(session, upstream_response, ctx).await
//! }
29//! ```
30
31pub mod circuit_breaker;
32pub mod compression;
33pub mod concurrency;
34pub mod headers;
35pub mod logging;
36pub mod rate_limit;
37pub mod sendfile;
38pub mod static_files;
39pub mod trusted_ranges;
40
41#[cfg(test)]
42mod tests;
43
44use async_trait::async_trait;
45use bytes::Bytes;
46use pingora::http::ResponseHeader;
47use pingora::prelude::*;
48use std::path::PathBuf;
49use std::sync::Arc;
50
51use crate::config::Config;
52
53pub use circuit_breaker::CircuitBreakerMiddleware;
54pub use compression::CompressionMiddleware;
55pub use concurrency::ConcurrencyMiddleware;
56pub use headers::HeadersMiddleware;
57pub use logging::LoggingMiddleware;
58pub use rate_limit::RateLimitMiddleware;
59pub use sendfile::SendfileMiddleware;
60pub use static_files::StaticFilesMiddleware;
61pub use trusted_ranges::TrustedRangesMiddleware;
62
/// Context for middleware execution
///
/// Stores state needed across middleware chain execution, including
/// request start time, response metadata, and flags.
///
/// A single value is handed by `&mut` to every filter run by
/// [`MiddlewareStack`], so a middleware that runs early can leave state for
/// the ones that run after it.
pub struct MiddlewareContext {
    /// Request start timestamp (for logging)
    pub request_start: std::time::Instant,

    /// Response status code (captured by logging middleware)
    pub status_code: u16,

    /// Response body size (captured by logging middleware)
    // NOTE(review): unit is presumably bytes — confirm against the logging middleware.
    pub body_size: usize,

    /// Sendfile state (set by sendfile middleware)
    pub sendfile: SendfileState,

    /// Compression state (set by compression middleware)
    pub compression: CompressionState,

    /// Static response (set by static files middleware to short-circuit proxy)
    ///
    /// `MiddlewareStack::apply_request_filters` returns early, skipping the
    /// remaining request filters, as soon as this is `Some`.
    pub static_response: Option<StaticResponse>,

    /// Concurrency permit (held for request duration, auto-released on drop)
    pub concurrency_permit: Option<tokio::sync::OwnedSemaphorePermit>,

    /// Streaming flag (set when SSE or other streaming content is detected)
    ///
    /// When true, disables response buffering to allow real-time streaming.
    /// Automatically set when:
    /// - Content-Type: text/event-stream (Server-Sent Events)
    /// - X-Accel-Buffering: no (nginx compatibility)
    pub streaming: bool,

    /// Trusted source flag (set by trusted_ranges middleware)
    ///
    /// When true, indicates request comes from a trusted proxy/CDN IP range.
    /// Used to bypass rate limiting and concurrency limits.
    pub trusted_source: bool,

    /// Real client IP (normalized by trusted_ranges middleware)
    ///
    /// Contains the actual client IP extracted from proxy headers (if trusted)
    /// or the socket IP (if not from trusted source). Used for:
    /// - Rate limiting (to rate limit real clients, not proxies)
    /// - Logging (to log actual client IPs)
    /// - X-Forwarded-For headers (to maintain proxy chain)
    pub real_client_ip: std::net::IpAddr,
}
112
113impl Default for MiddlewareContext {
114    fn default() -> Self {
115        Self {
116            request_start: std::time::Instant::now(),
117            status_code: 200,
118            body_size: 0,
119            sendfile: SendfileState::default(),
120            compression: CompressionState::default(),
121            static_response: None,
122            concurrency_permit: None,
123            streaming: false,
124            trusted_source: false,
125            real_client_ip: "0.0.0.0".parse().unwrap(),
126        }
127    }
128}
129
/// Body payload for a static response
///
/// Carried inside [`StaticResponse`]; the variant records where the body
/// bytes come from.
#[derive(Debug)]
pub enum StaticResponseBody {
    /// In-memory buffer (small files)
    InMemory(Bytes),
    /// Streamed from disk (large files)
    // NOTE(review): the path is presumably the file to stream; the actual
    // read happens outside this module — confirm against the consumer.
    Stream(PathBuf),
}
138
/// Pre-computed static response returned by the static files middleware
///
/// Stored in `MiddlewareContext::static_response`; its presence makes
/// `MiddlewareStack::apply_request_filters` stop before the remaining
/// request filters run.
#[derive(Debug)]
pub struct StaticResponse {
    /// Response header to send for the static file
    pub header: ResponseHeader,
    /// Body payload, either held in memory or referenced by path
    pub body: StaticResponseBody,
}
145
/// Tracks sendfile middleware state for a response
///
/// The derived `Default` yields the inactive state: `active` and `served`
/// are `false`, `path` and `body` are `None`.
#[derive(Debug, Default)]
pub struct SendfileState {
    /// Whether sendfile is engaged for this response
    pub active: bool,
    /// Optional path for logging/diagnostics
    pub path: Option<String>,
    /// In-memory body to serve (loaded once when active)
    pub body: Option<Bytes>,
    /// Flag indicating if body has been emitted downstream
    pub served: bool,
}
158
159impl SendfileState {
160    pub fn activate(&mut self, path: String, body: Bytes) {
161        self.active = true;
162        self.path = Some(path);
163        self.body = Some(body);
164        self.served = false;
165    }
166
167    pub fn reset(&mut self) {
168        self.active = false;
169        self.path = None;
170        self.body = None;
171        self.served = false;
172    }
173}
174
/// Compression algorithm in use for the current response
///
/// A fieldless, `Copy` enum. Equality is total, so it derives `Eq` alongside
/// `PartialEq`, and `Hash` so a value can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// Brotli compression (preferred, better ratio)
    Brotli,
    /// Gzip compression (fallback, wider support)
    Gzip,
}
183
/// Tracks compression middleware state across body chunks
///
/// Starts as [`CompressionState::Disabled`], the `#[default]` variant.
#[derive(Debug, Default)]
pub enum CompressionState {
    /// Compression disabled for this response
    #[default]
    Disabled,
    /// Compression enabled and buffering body chunks
    Pending {
        /// Buffered body bytes; empty when the state is first entered via `enable`
        buffer: Vec<u8>,
        /// Algorithm selected for this response
        encoding: CompressionEncoding,
    },
    /// Compression already applied
    Complete,
}
198
199impl CompressionState {
200    pub fn enable(&mut self, encoding: CompressionEncoding) {
201        *self = CompressionState::Pending {
202            buffer: Vec::new(),
203            encoding,
204        };
205    }
206
207    pub fn is_enabled(&self) -> bool {
208        matches!(self, CompressionState::Pending { .. })
209    }
210}
211
/// Middleware stack that composes all middlewares
///
/// This struct holds all middleware components and orchestrates their execution
/// in the correct order for both request and response phases.
///
/// Optional components are `None` when the configuration passed to
/// `MiddlewareStack::new` leaves them disabled.
pub struct MiddlewareStack {
    /// Client IP normalization and trusted-source detection; present when a
    /// trusted ranges file or client IP header is configured and construction
    /// succeeded
    pub trusted_ranges: Option<TrustedRangesMiddleware>,
    /// Static file serving; present when `static_enabled` is set
    pub static_files: Option<StaticFilesMiddleware>,
    /// Concurrent request limiting; present when `max_concurrent_requests > 0`
    pub concurrency: Option<ConcurrencyMiddleware>,
    /// Rate limiting; present when `rate_limit_enabled` is set
    pub rate_limit: Option<RateLimitMiddleware>,
    /// Circuit breaker; present when `circuit_breaker_enabled` is set
    pub circuit_breaker: Option<CircuitBreakerMiddleware>,
    /// Header injection; always present
    pub headers: HeadersMiddleware,
    /// Request/response logging; present when `log_requests` is set
    pub logging: Option<LoggingMiddleware>,
    /// X-Sendfile handling; present when `x_sendfile_enabled` is set
    pub sendfile: Option<SendfileMiddleware>,
    /// Response compression; present when `gzip_compression_enabled` is set
    pub compression: Option<CompressionMiddleware>,
}
227
228impl MiddlewareStack {
229    /// Create a new middleware stack from configuration
230    pub fn new(config: Arc<Config>) -> Self {
231        Self {
232            trusted_ranges: if config.trusted_ranges_file.is_some()
233                || config.client_ip_header.is_some()
234            {
235                TrustedRangesMiddleware::new(
236                    config.trusted_ranges_file.clone(),
237                    config.client_ip_header.clone(),
238                )
239                .ok()
240            } else {
241                None
242            },
243            static_files: if config.static_enabled {
244                Some(StaticFilesMiddleware::from_config(&config))
245            } else {
246                None
247            },
248            concurrency: if config.max_concurrent_requests > 0 {
249                Some(ConcurrencyMiddleware::new(
250                    true,
251                    config.max_concurrent_requests,
252                ))
253            } else {
254                None
255            },
256            rate_limit: if config.rate_limit_enabled {
257                Some(RateLimitMiddleware::new(
258                    true,
259                    config.rate_limit_requests_per_sec,
260                    config.rate_limit_burst_size,
261                ))
262            } else {
263                None
264            },
265            circuit_breaker: if config.circuit_breaker_enabled {
266                Some(CircuitBreakerMiddleware::new(
267                    true,
268                    config.circuit_breaker_failure_threshold,
269                    config.circuit_breaker_timeout_secs,
270                ))
271            } else {
272                None
273            },
274            headers: HeadersMiddleware::new(config.clone()),
275            logging: if config.log_requests {
276                Some(LoggingMiddleware::new())
277            } else {
278                None
279            },
280            sendfile: if config.x_sendfile_enabled {
281                Some(SendfileMiddleware::new())
282            } else {
283                None
284            },
285            compression: if config.gzip_compression_enabled {
286                Some(CompressionMiddleware::new())
287            } else {
288                None
289            },
290        }
291    }
292
293    /// Apply all request filters in order
294    ///
295    /// This is called from ProxyHttp::request_filter() to process the incoming request
296    /// before it's sent to the upstream server.
297    pub async fn apply_request_filters(
298        &self,
299        session: &mut Session,
300        ctx: &mut MiddlewareContext,
301    ) -> Result<()> {
302        // 0. Trusted ranges - normalize client IP and set trusted flag (must run FIRST)
303        if let Some(ref trusted_ranges) = self.trusted_ranges {
304            trusted_ranges.request_filter(session, ctx).await?;
305        }
306
307        // 1. Static files - serve directly if path matches (skip all other middleware + proxy)
308        if let Some(ref static_files) = self.static_files {
309            static_files.request_filter(session, ctx).await?;
310
311            // If static response was set, short-circuit the proxy
312            if ctx.static_response.is_some() {
313                return Ok(());
314            }
315        }
316
317        // 2. Concurrency limiting - limit total concurrent requests
318        if let Some(ref concurrency) = self.concurrency {
319            concurrency.request_filter(session, ctx).await?;
320        }
321
322        // 3. Rate limiting - check limits (fail fast)
323        if let Some(ref rate_limit) = self.rate_limit {
324            rate_limit.request_filter(session, ctx).await?;
325        }
326
327        // 4. Circuit breaker - check if upstream is available
328        if let Some(ref circuit_breaker) = self.circuit_breaker {
329            circuit_breaker.request_filter(session, ctx).await?;
330        }
331
332        // 5. Headers middleware - add X-Forwarded-* headers
333        self.headers.request_filter(session, ctx).await?;
334
335        // 6. Logging middleware - record request start
336        if let Some(ref logging) = self.logging {
337            logging.request_filter(session, ctx).await?;
338        }
339
340        Ok(())
341    }
342
343    /// Apply all response filters in order
344    ///
345    /// This is called from ProxyHttp::response_filter() to process the upstream response
346    /// before it's sent to the client.
347    pub async fn apply_response_filters(
348        &self,
349        session: &mut Session,
350        upstream_response: &mut ResponseHeader,
351        ctx: &mut MiddlewareContext,
352    ) -> Result<()> {
353        // 1. Circuit breaker - record upstream response (success/failure)
354        if let Some(ref circuit_breaker) = self.circuit_breaker {
355            circuit_breaker
356                .response_filter(session, upstream_response, ctx)
357                .await?;
358        }
359
360        // 2. Sendfile middleware - handle X-Sendfile headers
361        if let Some(ref sendfile) = self.sendfile {
362            sendfile
363                .response_filter(session, upstream_response, ctx)
364                .await?;
365        }
366
367        // 2. Compression middleware - compress response if applicable
368        if let Some(ref compression) = self.compression {
369            compression
370                .response_filter(session, upstream_response, ctx)
371                .await?;
372        }
373
374        // 3. Logging middleware - log response (must be last to capture all metadata)
375        if let Some(ref logging) = self.logging {
376            logging
377                .response_filter(session, upstream_response, ctx)
378                .await?;
379        }
380
381        Ok(())
382    }
383}
384
/// Trait for middleware components
///
/// Each middleware implements this trait to define its behavior during
/// request and response phases. Both methods are optional - middlewares
/// can implement only the phase they care about.
///
/// Implementors must be `Send + Sync`. The default implementation of each
/// method does nothing and returns `Ok(())`.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Process request before sending to upstream
    ///
    /// Receives the session and the shared [`MiddlewareContext`], both
    /// mutably. The default is a no-op that returns `Ok(())`.
    async fn request_filter(
        &self,
        _session: &mut Session,
        _ctx: &mut MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Process response before sending to client
    ///
    /// Receives the session, the upstream response header, and the shared
    /// [`MiddlewareContext`], all mutably. The default is a no-op that
    /// returns `Ok(())`.
    async fn response_filter(
        &self,
        _session: &mut Session,
        _upstream_response: &mut ResponseHeader,
        _ctx: &mut MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }
}