sentinel_proxy/proxy/
context.rs

1//! Request context for the proxy request lifecycle.
2//!
3//! The `RequestContext` struct maintains state throughout a single request,
4//! including timing, routing decisions, and metadata for logging.
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use sentinel_config::{BodyStreamingMode, Config, RouteConfig, ServiceType};
10
11use crate::websocket::WebSocketHandler;
12
/// Rate limit state reported back to the client in response headers.
///
/// This is a plain value type. It derives `PartialEq`/`Eq` so whole values can
/// be compared directly rather than field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitHeaderInfo {
    /// Maximum number of requests allowed per window.
    pub limit: u32,
    /// Requests still available in the current window.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the current window resets.
    pub reset_at: u64,
}
23
24/// Request context maintained throughout the request lifecycle.
25///
26/// This struct uses a hybrid approach:
27/// - Immutable fields (start_time) are private with getters
28/// - Mutable fields are public(crate) for efficient access within the proxy module
29pub struct RequestContext {
30    /// Request start time (immutable after creation)
31    start_time: Instant,
32
33    // === Tracing ===
34    /// Unique trace ID for request tracing (also used as correlation_id)
35    pub(crate) trace_id: String,
36
37    // === Global config (cached once per request) ===
38    /// Cached global configuration snapshot for this request
39    pub(crate) config: Option<Arc<Config>>,
40
41    // === Routing ===
42    /// Selected route ID
43    pub(crate) route_id: Option<String>,
44    /// Cached route configuration (avoids duplicate route matching)
45    pub(crate) route_config: Option<Arc<RouteConfig>>,
46    /// Selected upstream
47    pub(crate) upstream: Option<String>,
48    /// Number of upstream attempts
49    pub(crate) upstream_attempts: u32,
50
51    // === Request metadata (cached for logging) ===
52    /// HTTP method
53    pub(crate) method: String,
54    /// Request path
55    pub(crate) path: String,
56    /// Query string
57    pub(crate) query: Option<String>,
58
59    // === Client info ===
60    /// Client IP address
61    pub(crate) client_ip: String,
62    /// User-Agent header
63    pub(crate) user_agent: Option<String>,
64    /// Referer header
65    pub(crate) referer: Option<String>,
66    /// Host header
67    pub(crate) host: Option<String>,
68
69    // === Body tracking ===
70    /// Request body bytes received
71    pub(crate) request_body_bytes: u64,
72    /// Response body bytes (set during response)
73    pub(crate) response_bytes: u64,
74
75    // === Connection tracking ===
76    /// Whether the upstream connection was reused
77    pub(crate) connection_reused: bool,
78    /// Whether this request is a WebSocket upgrade
79    pub(crate) is_websocket_upgrade: bool,
80
81    // === WebSocket Inspection ===
82    /// Whether WebSocket frame inspection is enabled for this connection
83    pub(crate) websocket_inspection_enabled: bool,
84    /// Whether to skip inspection (e.g., due to compression negotiation)
85    pub(crate) websocket_skip_inspection: bool,
86    /// Agent IDs for WebSocket frame inspection
87    pub(crate) websocket_inspection_agents: Vec<String>,
88    /// WebSocket frame handler (created after 101 upgrade)
89    pub(crate) websocket_handler: Option<Arc<WebSocketHandler>>,
90
91    // === Caching ===
92    /// Whether this request is eligible for caching
93    pub(crate) cache_eligible: bool,
94
95    // === Body Inspection ===
96    /// Whether body inspection is enabled for this request
97    pub(crate) body_inspection_enabled: bool,
98    /// Bytes already sent to agent for inspection
99    pub(crate) body_bytes_inspected: u64,
100    /// Accumulated body buffer for agent inspection
101    pub(crate) body_buffer: Vec<u8>,
102    /// Agent IDs to use for body inspection
103    pub(crate) body_inspection_agents: Vec<String>,
104
105    // === Body Decompression ===
106    /// Whether decompression is enabled for body inspection
107    pub(crate) decompression_enabled: bool,
108    /// Content-Encoding of the request body (if compressed)
109    pub(crate) body_content_encoding: Option<String>,
110    /// Maximum decompression ratio allowed
111    pub(crate) max_decompression_ratio: f64,
112    /// Maximum decompressed size allowed
113    pub(crate) max_decompression_bytes: usize,
114    /// Whether decompression was performed
115    pub(crate) body_was_decompressed: bool,
116
117    // === Rate Limiting ===
118    /// Rate limit info for response headers (set during request_filter)
119    pub(crate) rate_limit_info: Option<RateLimitHeaderInfo>,
120
121    // === GeoIP Filtering ===
122    /// Country code from GeoIP lookup (ISO 3166-1 alpha-2)
123    pub(crate) geo_country_code: Option<String>,
124    /// Whether a geo lookup was performed for this request
125    pub(crate) geo_lookup_performed: bool,
126
127    // === Body Streaming ===
128    /// Body streaming mode for request body inspection
129    pub(crate) request_body_streaming_mode: BodyStreamingMode,
130    /// Current chunk index for request body streaming
131    pub(crate) request_body_chunk_index: u32,
132    /// Whether agent needs more data (streaming mode)
133    pub(crate) agent_needs_more: bool,
134    /// Body streaming mode for response body inspection
135    pub(crate) response_body_streaming_mode: BodyStreamingMode,
136    /// Current chunk index for response body streaming
137    pub(crate) response_body_chunk_index: u32,
138    /// Response body bytes inspected
139    pub(crate) response_body_bytes_inspected: u64,
140    /// Response body inspection enabled
141    pub(crate) response_body_inspection_enabled: bool,
142    /// Agent IDs for response body inspection
143    pub(crate) response_body_inspection_agents: Vec<String>,
144}
145
146impl RequestContext {
147    /// Create a new empty request context with the current timestamp.
148    pub fn new() -> Self {
149        Self {
150            start_time: Instant::now(),
151            trace_id: String::new(),
152            config: None,
153            route_id: None,
154            route_config: None,
155            upstream: None,
156            upstream_attempts: 0,
157            method: String::new(),
158            path: String::new(),
159            query: None,
160            client_ip: String::new(),
161            user_agent: None,
162            referer: None,
163            host: None,
164            request_body_bytes: 0,
165            response_bytes: 0,
166            connection_reused: false,
167            is_websocket_upgrade: false,
168            websocket_inspection_enabled: false,
169            websocket_skip_inspection: false,
170            websocket_inspection_agents: Vec::new(),
171            websocket_handler: None,
172            cache_eligible: false,
173            body_inspection_enabled: false,
174            body_bytes_inspected: 0,
175            body_buffer: Vec::new(),
176            body_inspection_agents: Vec::new(),
177            decompression_enabled: false,
178            body_content_encoding: None,
179            max_decompression_ratio: 100.0,
180            max_decompression_bytes: 10 * 1024 * 1024, // 10MB
181            body_was_decompressed: false,
182            rate_limit_info: None,
183            geo_country_code: None,
184            geo_lookup_performed: false,
185            request_body_streaming_mode: BodyStreamingMode::Buffer,
186            request_body_chunk_index: 0,
187            agent_needs_more: false,
188            response_body_streaming_mode: BodyStreamingMode::Buffer,
189            response_body_chunk_index: 0,
190            response_body_bytes_inspected: 0,
191            response_body_inspection_enabled: false,
192            response_body_inspection_agents: Vec::new(),
193        }
194    }
195
196    // === Immutable field accessors ===
197
198    /// Get the request start time.
199    #[inline]
200    pub fn start_time(&self) -> Instant {
201        self.start_time
202    }
203
204    /// Get elapsed duration since request start.
205    #[inline]
206    pub fn elapsed(&self) -> std::time::Duration {
207        self.start_time.elapsed()
208    }
209
210    // === Read-only accessors ===
211
212    /// Get trace_id (alias for backwards compatibility with correlation_id usage).
213    #[inline]
214    pub fn correlation_id(&self) -> &str {
215        &self.trace_id
216    }
217
218    /// Get the trace ID.
219    #[inline]
220    pub fn trace_id(&self) -> &str {
221        &self.trace_id
222    }
223
224    /// Get the route ID, if set.
225    #[inline]
226    pub fn route_id(&self) -> Option<&str> {
227        self.route_id.as_deref()
228    }
229
230    /// Get the upstream ID, if set.
231    #[inline]
232    pub fn upstream(&self) -> Option<&str> {
233        self.upstream.as_deref()
234    }
235
236    /// Get the cached route configuration, if set.
237    #[inline]
238    pub fn route_config(&self) -> Option<&Arc<RouteConfig>> {
239        self.route_config.as_ref()
240    }
241
242    /// Get the cached global configuration, if set.
243    #[inline]
244    pub fn global_config(&self) -> Option<&Arc<Config>> {
245        self.config.as_ref()
246    }
247
248    /// Get the service type from cached route config.
249    #[inline]
250    pub fn service_type(&self) -> Option<ServiceType> {
251        self.route_config.as_ref().map(|c| c.service_type.clone())
252    }
253
254    /// Get the number of upstream attempts.
255    #[inline]
256    pub fn upstream_attempts(&self) -> u32 {
257        self.upstream_attempts
258    }
259
260    /// Get the HTTP method.
261    #[inline]
262    pub fn method(&self) -> &str {
263        &self.method
264    }
265
266    /// Get the request path.
267    #[inline]
268    pub fn path(&self) -> &str {
269        &self.path
270    }
271
272    /// Get the query string, if present.
273    #[inline]
274    pub fn query(&self) -> Option<&str> {
275        self.query.as_deref()
276    }
277
278    /// Get the client IP address.
279    #[inline]
280    pub fn client_ip(&self) -> &str {
281        &self.client_ip
282    }
283
284    /// Get the User-Agent header, if present.
285    #[inline]
286    pub fn user_agent(&self) -> Option<&str> {
287        self.user_agent.as_deref()
288    }
289
290    /// Get the Referer header, if present.
291    #[inline]
292    pub fn referer(&self) -> Option<&str> {
293        self.referer.as_deref()
294    }
295
296    /// Get the Host header, if present.
297    #[inline]
298    pub fn host(&self) -> Option<&str> {
299        self.host.as_deref()
300    }
301
302    /// Get the response body size in bytes.
303    #[inline]
304    pub fn response_bytes(&self) -> u64 {
305        self.response_bytes
306    }
307
308    /// Get the GeoIP country code, if determined.
309    #[inline]
310    pub fn geo_country_code(&self) -> Option<&str> {
311        self.geo_country_code.as_deref()
312    }
313
314    /// Check if a geo lookup was performed for this request.
315    #[inline]
316    pub fn geo_lookup_performed(&self) -> bool {
317        self.geo_lookup_performed
318    }
319
320    // === Mutation helpers ===
321
322    /// Set the trace ID.
323    #[inline]
324    pub fn set_trace_id(&mut self, trace_id: impl Into<String>) {
325        self.trace_id = trace_id.into();
326    }
327
328    /// Set the route ID.
329    #[inline]
330    pub fn set_route_id(&mut self, route_id: impl Into<String>) {
331        self.route_id = Some(route_id.into());
332    }
333
334    /// Set the upstream ID.
335    #[inline]
336    pub fn set_upstream(&mut self, upstream: impl Into<String>) {
337        self.upstream = Some(upstream.into());
338    }
339
340    /// Increment upstream attempt counter.
341    #[inline]
342    pub fn inc_upstream_attempts(&mut self) {
343        self.upstream_attempts += 1;
344    }
345
346    /// Set response bytes.
347    #[inline]
348    pub fn set_response_bytes(&mut self, bytes: u64) {
349        self.response_bytes = bytes;
350    }
351}
352
353impl Default for RequestContext {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359// ============================================================================
360// Tests
361// ============================================================================
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_header_info() {
        let info = RateLimitHeaderInfo {
            limit: 100,
            remaining: 42,
            reset_at: 1704067200,
        };

        assert_eq!(info.limit, 100);
        assert_eq!(info.remaining, 42);
        assert_eq!(info.reset_at, 1704067200);
    }

    /// Asserts every default that `RequestContext::new` is expected to establish,
    /// including the non-zero decompression limits and the streaming modes.
    fn assert_fresh_context(ctx: &RequestContext) {
        // Identifiers, cached config and routing start unset.
        assert!(ctx.trace_id.is_empty());
        assert!(ctx.config.is_none());
        assert!(ctx.route_id.is_none());
        assert!(ctx.route_config.is_none());
        assert!(ctx.upstream.is_none());
        assert_eq!(ctx.upstream_attempts, 0);

        // Counters and flags start at zero / false.
        assert_eq!(ctx.request_body_bytes, 0);
        assert_eq!(ctx.response_bytes, 0);
        assert!(!ctx.is_websocket_upgrade);
        assert!(!ctx.body_inspection_enabled);
        assert!(ctx.body_buffer.is_empty());
        assert!(ctx.rate_limit_info.is_none());
        assert!(ctx.geo_country_code.is_none());
        assert!(!ctx.geo_lookup_performed);

        // Decompression is off but its guard limits are pre-populated.
        assert!(!ctx.decompression_enabled);
        assert_eq!(ctx.max_decompression_ratio, 100.0);
        assert_eq!(ctx.max_decompression_bytes, 10 * 1024 * 1024);

        // Both directions default to buffered inspection.
        assert!(matches!(ctx.request_body_streaming_mode, BodyStreamingMode::Buffer));
        assert!(matches!(ctx.response_body_streaming_mode, BodyStreamingMode::Buffer));
    }

    #[test]
    fn test_request_context_default() {
        assert_fresh_context(&RequestContext::new());
    }

    #[test]
    fn test_default_trait_matches_new() {
        assert_fresh_context(&RequestContext::default());
    }

    #[test]
    fn test_request_context_rate_limit_info() {
        let mut ctx = RequestContext::new();

        // Initially no rate limit info
        assert!(ctx.rate_limit_info.is_none());

        // Set rate limit info
        ctx.rate_limit_info = Some(RateLimitHeaderInfo {
            limit: 50,
            remaining: 25,
            reset_at: 1704067300,
        });

        let info = ctx.rate_limit_info.as_ref().expect("rate limit info was just set");
        assert_eq!(info.limit, 50);
        assert_eq!(info.remaining, 25);
        assert_eq!(info.reset_at, 1704067300);
    }

    #[test]
    fn test_request_context_elapsed() {
        let ctx = RequestContext::new();

        // Start time is in the past and elapsed time never goes backwards.
        assert!(ctx.start_time() <= Instant::now());
        let first = ctx.elapsed();
        let second = ctx.elapsed();
        assert!(second >= first);
        assert!(first.as_secs() < 1);
    }

    #[test]
    fn test_request_context_setters() {
        let mut ctx = RequestContext::new();

        ctx.set_trace_id("trace-123");
        assert_eq!(ctx.trace_id(), "trace-123");
        assert_eq!(ctx.correlation_id(), "trace-123");

        ctx.set_route_id("my-route");
        assert_eq!(ctx.route_id(), Some("my-route"));

        ctx.set_upstream("backend-pool");
        assert_eq!(ctx.upstream(), Some("backend-pool"));

        ctx.inc_upstream_attempts();
        ctx.inc_upstream_attempts();
        assert_eq!(ctx.upstream_attempts(), 2);

        ctx.set_response_bytes(1024);
        assert_eq!(ctx.response_bytes(), 1024);
    }

    #[test]
    fn test_request_context_metadata_accessors() {
        let mut ctx = RequestContext::new();

        // Unset metadata reads back as empty / None.
        assert_eq!(ctx.method(), "");
        assert_eq!(ctx.path(), "");
        assert_eq!(ctx.query(), None);
        assert_eq!(ctx.client_ip(), "");
        assert_eq!(ctx.user_agent(), None);
        assert_eq!(ctx.referer(), None);
        assert_eq!(ctx.host(), None);
        assert_eq!(ctx.geo_country_code(), None);
        assert!(!ctx.geo_lookup_performed());
        assert!(ctx.global_config().is_none());
        assert!(ctx.route_config().is_none());
        assert!(ctx.service_type().is_none());

        ctx.method = "GET".to_string();
        ctx.path = "/api/items".to_string();
        ctx.query = Some("page=2".to_string());
        ctx.client_ip = "127.0.0.1".to_string();
        ctx.user_agent = Some("curl/8.0".to_string());
        ctx.referer = Some("https://example.com/".to_string());
        ctx.host = Some("example.com".to_string());
        ctx.geo_country_code = Some("US".to_string());
        ctx.geo_lookup_performed = true;

        assert_eq!(ctx.method(), "GET");
        assert_eq!(ctx.path(), "/api/items");
        assert_eq!(ctx.query(), Some("page=2"));
        assert_eq!(ctx.client_ip(), "127.0.0.1");
        assert_eq!(ctx.user_agent(), Some("curl/8.0"));
        assert_eq!(ctx.referer(), Some("https://example.com/"));
        assert_eq!(ctx.host(), Some("example.com"));
        assert_eq!(ctx.geo_country_code(), Some("US"));
        assert!(ctx.geo_lookup_performed());
    }
}