sentinel_proxy/proxy/
context.rs

1//! Request context for the proxy request lifecycle.
2//!
3//! The `RequestContext` struct maintains state throughout a single request,
4//! including timing, routing decisions, and metadata for logging.
5
6use std::sync::Arc;
7use std::time::Instant;
8
9use sentinel_config::{BodyStreamingMode, Config, RouteConfig, ServiceType};
10
11use crate::inference::StreamingTokenCounter;
12use crate::websocket::WebSocketHandler;
13
/// Reason why fallback routing was triggered.
///
/// The `Display` implementation renders each variant as a compact
/// `snake_case` label (for example `error_code_503`). Values can be compared
/// with `==`, which makes them usable in `assert_eq!` and in equality checks
/// on the request context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackReason {
    /// Primary upstream health check failed
    HealthCheckFailed,
    /// Token budget exhausted for the request
    BudgetExhausted,
    /// Response latency exceeded threshold
    LatencyThreshold { observed_ms: u64, threshold_ms: u64 },
    /// Upstream returned an error code that triggers fallback
    ErrorCode(u16),
    /// Connection to upstream failed
    ConnectionError(String),
}

impl std::fmt::Display for FallbackReason {
    /// Write the label for this reason.
    ///
    /// Unit variants produce a fixed string; the remaining variants embed
    /// their payload. The `ConnectionError` message is appended verbatim,
    /// without any escaping or truncation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant labels need no formatting machinery.
            Self::HealthCheckFailed => f.write_str("health_check_failed"),
            Self::BudgetExhausted => f.write_str("budget_exhausted"),
            Self::LatencyThreshold {
                observed_ms,
                threshold_ms,
            } => write!(
                f,
                "latency_threshold_{}ms_exceeded_{}ms",
                observed_ms, threshold_ms
            ),
            Self::ErrorCode(code) => write!(f, "error_code_{}", code),
            Self::ConnectionError(msg) => write!(f, "connection_error_{}", msg),
        }
    }
}
43
/// Rate limit header information for response headers.
///
/// A plain value type made of integers only, so two instances can be
/// compared directly with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitHeaderInfo {
    /// Maximum requests allowed per window
    pub limit: u32,
    /// Remaining requests in current window
    pub remaining: u32,
    /// Unix timestamp (seconds) when the window resets
    pub reset_at: u64,
}
54
/// Request context maintained throughout the request lifecycle.
///
/// One instance carries the state of a single request: timing, routing
/// decisions, inspection settings, inference accounting, and metadata cached
/// for logging.
///
/// This struct uses a hybrid approach to field visibility:
/// - Immutable fields (`start_time`) are private and exposed through getters
/// - Mutable fields are `pub(crate)` for efficient access within the crate
///
/// `RequestContext::new` starts every field at an empty/zero/`false`/`None`
/// value, except for the decompression limits and the body streaming modes,
/// whose initial values are noted on the fields below.
pub struct RequestContext {
    /// Request start time (set once in `new()`; read via `start_time()` and `elapsed()`)
    start_time: Instant,

    // === Tracing ===
    /// Unique trace ID for request tracing (also returned by `correlation_id()`)
    pub(crate) trace_id: String,

    // === Global config (cached once per request) ===
    /// Cached global configuration snapshot for this request
    pub(crate) config: Option<Arc<Config>>,

    // === Routing ===
    /// Selected route ID
    pub(crate) route_id: Option<String>,
    /// Cached route configuration (avoids duplicate route matching)
    pub(crate) route_config: Option<Arc<RouteConfig>>,
    /// Selected upstream pool ID (replaced by `record_fallback` and `record_model_routing`)
    pub(crate) upstream: Option<String>,
    /// Selected upstream peer address (IP:port) for feedback reporting
    pub(crate) selected_upstream_address: Option<String>,
    /// Number of upstream attempts (incremented by `inc_upstream_attempts`)
    pub(crate) upstream_attempts: u32,

    // === Scope (for namespaced configurations) ===
    /// Namespace for this request (if routed to a namespace scope)
    pub(crate) namespace: Option<String>,
    /// Service for this request (if routed to a service scope)
    pub(crate) service: Option<String>,

    // === Request metadata (cached for logging) ===
    /// HTTP method
    pub(crate) method: String,
    /// Request path
    pub(crate) path: String,
    /// Query string
    pub(crate) query: Option<String>,

    // === Client info ===
    /// Client IP address
    pub(crate) client_ip: String,
    /// User-Agent header
    pub(crate) user_agent: Option<String>,
    /// Referer header
    pub(crate) referer: Option<String>,
    /// Host header
    pub(crate) host: Option<String>,

    // === Body tracking ===
    /// Request body bytes received
    pub(crate) request_body_bytes: u64,
    /// Response body bytes (set during response)
    pub(crate) response_bytes: u64,

    // === Connection tracking ===
    /// Whether the upstream connection was reused
    pub(crate) connection_reused: bool,
    /// Whether this request is a WebSocket upgrade
    pub(crate) is_websocket_upgrade: bool,

    // === WebSocket Inspection ===
    /// Whether WebSocket frame inspection is enabled for this connection
    pub(crate) websocket_inspection_enabled: bool,
    /// Whether to skip inspection (e.g., due to compression negotiation)
    pub(crate) websocket_skip_inspection: bool,
    /// Agent IDs for WebSocket frame inspection
    pub(crate) websocket_inspection_agents: Vec<String>,
    /// WebSocket frame handler (created after 101 upgrade)
    pub(crate) websocket_handler: Option<Arc<WebSocketHandler>>,

    // === Caching ===
    /// Whether this request is eligible for caching
    pub(crate) cache_eligible: bool,

    // === Body Inspection ===
    /// Whether body inspection is enabled for this request
    pub(crate) body_inspection_enabled: bool,
    /// Bytes already sent to agent for inspection
    pub(crate) body_bytes_inspected: u64,
    /// Accumulated body buffer for agent inspection
    pub(crate) body_buffer: Vec<u8>,
    /// Agent IDs to use for body inspection
    pub(crate) body_inspection_agents: Vec<String>,

    // === Body Decompression ===
    /// Whether decompression is enabled for body inspection
    pub(crate) decompression_enabled: bool,
    /// Content-Encoding of the request body (if compressed)
    pub(crate) body_content_encoding: Option<String>,
    /// Maximum decompression ratio allowed (`new()` starts this at `100.0`)
    pub(crate) max_decompression_ratio: f64,
    /// Maximum decompressed size allowed, in bytes (`new()` starts this at 10 MiB)
    pub(crate) max_decompression_bytes: usize,
    /// Whether decompression was performed
    pub(crate) body_was_decompressed: bool,

    // === Rate Limiting ===
    /// Rate limit info for response headers (set during request_filter)
    pub(crate) rate_limit_info: Option<RateLimitHeaderInfo>,

    // === GeoIP Filtering ===
    /// Country code from GeoIP lookup (ISO 3166-1 alpha-2)
    pub(crate) geo_country_code: Option<String>,
    /// Whether a geo lookup was performed for this request
    pub(crate) geo_lookup_performed: bool,

    // === Body Streaming ===
    /// Body streaming mode for request body inspection (`new()` starts with `Buffer`)
    pub(crate) request_body_streaming_mode: BodyStreamingMode,
    /// Current chunk index for request body streaming
    pub(crate) request_body_chunk_index: u32,
    /// Whether agent needs more data (streaming mode)
    pub(crate) agent_needs_more: bool,
    /// Body streaming mode for response body inspection (`new()` starts with `Buffer`)
    pub(crate) response_body_streaming_mode: BodyStreamingMode,
    /// Current chunk index for response body streaming
    pub(crate) response_body_chunk_index: u32,
    /// Response body bytes inspected
    pub(crate) response_body_bytes_inspected: u64,
    /// Response body inspection enabled
    pub(crate) response_body_inspection_enabled: bool,
    /// Agent IDs for response body inspection
    pub(crate) response_body_inspection_agents: Vec<String>,

    // === OpenTelemetry Tracing ===
    /// OpenTelemetry request span (if tracing enabled); `traceparent()` returns `None` without it
    pub(crate) otel_span: Option<crate::otel::RequestSpan>,
    /// W3C trace context parsed from incoming request; supplies the sampled flag for `traceparent()`
    pub(crate) trace_context: Option<crate::otel::TraceContext>,

    // === Inference Rate Limiting ===
    /// Whether inference rate limiting is enabled for this route
    pub(crate) inference_rate_limit_enabled: bool,
    /// Estimated tokens for this request (used for rate limiting)
    pub(crate) inference_estimated_tokens: u64,
    /// Rate limit key used (client IP, API key, etc.)
    pub(crate) inference_rate_limit_key: Option<String>,
    /// Model name detected from request
    pub(crate) inference_model: Option<String>,
    /// Provider override from model-based routing (for cross-provider routing)
    pub(crate) inference_provider_override: Option<sentinel_config::InferenceProvider>,
    /// Whether model-based routing was used to select the upstream
    pub(crate) model_routing_used: bool,
    /// Actual tokens from response (filled in after response)
    pub(crate) inference_actual_tokens: Option<u64>,

    // === Token Budget Tracking ===
    /// Whether budget tracking is enabled for this route
    pub(crate) inference_budget_enabled: bool,
    /// Budget remaining after this request (set after response)
    pub(crate) inference_budget_remaining: Option<i64>,
    /// Period reset timestamp (Unix seconds)
    pub(crate) inference_budget_period_reset: Option<u64>,
    /// Whether budget was exhausted (429 sent)
    pub(crate) inference_budget_exhausted: bool,

    // === Cost Attribution ===
    /// Whether cost attribution is enabled for this route
    pub(crate) inference_cost_enabled: bool,
    /// Calculated cost for this request (set after response)
    pub(crate) inference_request_cost: Option<f64>,
    /// Input tokens for cost calculation
    pub(crate) inference_input_tokens: u64,
    /// Output tokens for cost calculation
    pub(crate) inference_output_tokens: u64,

    // === Streaming Token Counting ===
    /// Whether this is a streaming (SSE) response
    pub(crate) inference_streaming_response: bool,
    /// Streaming token counter for SSE responses
    pub(crate) inference_streaming_counter: Option<StreamingTokenCounter>,

    // === Fallback Routing ===
    /// Current fallback attempt number (0 = primary, 1+ = fallback)
    pub(crate) fallback_attempt: u32,
    /// List of upstream IDs that have been tried (appended to by `record_fallback`)
    pub(crate) tried_upstreams: Vec<String>,
    /// Reason for triggering fallback (if fallback was used); holds the most recent reason
    pub(crate) fallback_reason: Option<FallbackReason>,
    /// Original upstream ID before fallback (primary); captured on the first fallback only
    pub(crate) original_upstream: Option<String>,
    /// Model mapping applied: (original_model, mapped_model)
    pub(crate) model_mapping_applied: Option<(String, String)>,
    /// Whether fallback should be retried after response
    pub(crate) should_retry_with_fallback: bool,

    // === Semantic Guardrails ===
    /// Whether guardrails are enabled for this route
    pub(crate) guardrails_enabled: bool,
    /// Prompt injection detected but allowed (add warning header)
    pub(crate) guardrail_warning: bool,
    /// Categories of prompt injection detected (for logging)
    pub(crate) guardrail_detection_categories: Vec<String>,
    /// PII categories detected in response (for logging)
    pub(crate) pii_detection_categories: Vec<String>,

    // === Shadow Traffic ===
    /// Pending shadow request info (stored for deferred execution after body buffering)
    pub(crate) shadow_pending: Option<ShadowPendingRequest>,
    /// Whether shadow request was sent for this request
    pub(crate) shadow_sent: bool,
}
262
263/// Pending shadow request information stored in context for deferred execution
264#[derive(Clone)]
265pub struct ShadowPendingRequest {
266    /// Cloned request headers for shadow
267    pub headers: pingora::http::RequestHeader,
268    /// Shadow manager (wrapped in Arc for Clone)
269    pub manager: std::sync::Arc<crate::shadow::ShadowManager>,
270    /// Request context for shadow (client IP, path, method, etc.)
271    pub request_ctx: crate::upstream::RequestContext,
272    /// Whether body should be included
273    pub include_body: bool,
274}
275
276impl RequestContext {
277    /// Create a new empty request context with the current timestamp.
278    pub fn new() -> Self {
279        Self {
280            start_time: Instant::now(),
281            trace_id: String::new(),
282            config: None,
283            route_id: None,
284            route_config: None,
285            upstream: None,
286            selected_upstream_address: None,
287            upstream_attempts: 0,
288            namespace: None,
289            service: None,
290            method: String::new(),
291            path: String::new(),
292            query: None,
293            client_ip: String::new(),
294            user_agent: None,
295            referer: None,
296            host: None,
297            request_body_bytes: 0,
298            response_bytes: 0,
299            connection_reused: false,
300            is_websocket_upgrade: false,
301            websocket_inspection_enabled: false,
302            websocket_skip_inspection: false,
303            websocket_inspection_agents: Vec::new(),
304            websocket_handler: None,
305            cache_eligible: false,
306            body_inspection_enabled: false,
307            body_bytes_inspected: 0,
308            body_buffer: Vec::new(),
309            body_inspection_agents: Vec::new(),
310            decompression_enabled: false,
311            body_content_encoding: None,
312            max_decompression_ratio: 100.0,
313            max_decompression_bytes: 10 * 1024 * 1024, // 10MB
314            body_was_decompressed: false,
315            rate_limit_info: None,
316            geo_country_code: None,
317            geo_lookup_performed: false,
318            request_body_streaming_mode: BodyStreamingMode::Buffer,
319            request_body_chunk_index: 0,
320            agent_needs_more: false,
321            response_body_streaming_mode: BodyStreamingMode::Buffer,
322            response_body_chunk_index: 0,
323            response_body_bytes_inspected: 0,
324            response_body_inspection_enabled: false,
325            response_body_inspection_agents: Vec::new(),
326            otel_span: None,
327            trace_context: None,
328            inference_rate_limit_enabled: false,
329            inference_estimated_tokens: 0,
330            inference_rate_limit_key: None,
331            inference_model: None,
332            inference_provider_override: None,
333            model_routing_used: false,
334            inference_actual_tokens: None,
335            inference_budget_enabled: false,
336            inference_budget_remaining: None,
337            inference_budget_period_reset: None,
338            inference_budget_exhausted: false,
339            inference_cost_enabled: false,
340            inference_request_cost: None,
341            inference_input_tokens: 0,
342            inference_output_tokens: 0,
343            inference_streaming_response: false,
344            inference_streaming_counter: None,
345            fallback_attempt: 0,
346            tried_upstreams: Vec::new(),
347            fallback_reason: None,
348            original_upstream: None,
349            model_mapping_applied: None,
350            should_retry_with_fallback: false,
351            guardrails_enabled: false,
352            guardrail_warning: false,
353            guardrail_detection_categories: Vec::new(),
354            pii_detection_categories: Vec::new(),
355            shadow_pending: None,
356            shadow_sent: false,
357        }
358    }
359
360    // === Immutable field accessors ===
361
362    /// Get the request start time.
363    #[inline]
364    pub fn start_time(&self) -> Instant {
365        self.start_time
366    }
367
368    /// Get elapsed duration since request start.
369    #[inline]
370    pub fn elapsed(&self) -> std::time::Duration {
371        self.start_time.elapsed()
372    }
373
374    // === Read-only accessors ===
375
376    /// Get trace_id (alias for backwards compatibility with correlation_id usage).
377    #[inline]
378    pub fn correlation_id(&self) -> &str {
379        &self.trace_id
380    }
381
382    /// Get the trace ID.
383    #[inline]
384    pub fn trace_id(&self) -> &str {
385        &self.trace_id
386    }
387
388    /// Get the route ID, if set.
389    #[inline]
390    pub fn route_id(&self) -> Option<&str> {
391        self.route_id.as_deref()
392    }
393
394    /// Get the upstream ID, if set.
395    #[inline]
396    pub fn upstream(&self) -> Option<&str> {
397        self.upstream.as_deref()
398    }
399
400    /// Get the selected upstream peer address (IP:port), if set.
401    #[inline]
402    pub fn selected_upstream_address(&self) -> Option<&str> {
403        self.selected_upstream_address.as_deref()
404    }
405
406    /// Get the cached route configuration, if set.
407    #[inline]
408    pub fn route_config(&self) -> Option<&Arc<RouteConfig>> {
409        self.route_config.as_ref()
410    }
411
412    /// Get the cached global configuration, if set.
413    #[inline]
414    pub fn global_config(&self) -> Option<&Arc<Config>> {
415        self.config.as_ref()
416    }
417
418    /// Get the service type from cached route config.
419    #[inline]
420    pub fn service_type(&self) -> Option<ServiceType> {
421        self.route_config.as_ref().map(|c| c.service_type.clone())
422    }
423
424    /// Get the number of upstream attempts.
425    #[inline]
426    pub fn upstream_attempts(&self) -> u32 {
427        self.upstream_attempts
428    }
429
430    /// Get the HTTP method.
431    #[inline]
432    pub fn method(&self) -> &str {
433        &self.method
434    }
435
436    /// Get the request path.
437    #[inline]
438    pub fn path(&self) -> &str {
439        &self.path
440    }
441
442    /// Get the query string, if present.
443    #[inline]
444    pub fn query(&self) -> Option<&str> {
445        self.query.as_deref()
446    }
447
448    /// Get the client IP address.
449    #[inline]
450    pub fn client_ip(&self) -> &str {
451        &self.client_ip
452    }
453
454    /// Get the User-Agent header, if present.
455    #[inline]
456    pub fn user_agent(&self) -> Option<&str> {
457        self.user_agent.as_deref()
458    }
459
460    /// Get the Referer header, if present.
461    #[inline]
462    pub fn referer(&self) -> Option<&str> {
463        self.referer.as_deref()
464    }
465
466    /// Get the Host header, if present.
467    #[inline]
468    pub fn host(&self) -> Option<&str> {
469        self.host.as_deref()
470    }
471
472    /// Get the response body size in bytes.
473    #[inline]
474    pub fn response_bytes(&self) -> u64 {
475        self.response_bytes
476    }
477
478    /// Get the GeoIP country code, if determined.
479    #[inline]
480    pub fn geo_country_code(&self) -> Option<&str> {
481        self.geo_country_code.as_deref()
482    }
483
484    /// Check if a geo lookup was performed for this request.
485    #[inline]
486    pub fn geo_lookup_performed(&self) -> bool {
487        self.geo_lookup_performed
488    }
489
490    /// Get traceparent header value for distributed tracing.
491    ///
492    /// Returns the W3C Trace Context traceparent header value if tracing is enabled.
493    /// Format: `{version}-{trace-id}-{span-id}-{trace-flags}`
494    #[inline]
495    pub fn traceparent(&self) -> Option<String> {
496        self.otel_span.as_ref().map(|span| {
497            let sampled = self.trace_context.as_ref().map(|c| c.sampled).unwrap_or(true);
498            crate::otel::create_traceparent(&span.trace_id, &span.span_id, sampled)
499        })
500    }
501
502    // === Mutation helpers ===
503
504    /// Set the trace ID.
505    #[inline]
506    pub fn set_trace_id(&mut self, trace_id: impl Into<String>) {
507        self.trace_id = trace_id.into();
508    }
509
510    /// Set the route ID.
511    #[inline]
512    pub fn set_route_id(&mut self, route_id: impl Into<String>) {
513        self.route_id = Some(route_id.into());
514    }
515
516    /// Set the upstream ID.
517    #[inline]
518    pub fn set_upstream(&mut self, upstream: impl Into<String>) {
519        self.upstream = Some(upstream.into());
520    }
521
522    /// Set the selected upstream peer address (IP:port).
523    #[inline]
524    pub fn set_selected_upstream_address(&mut self, address: impl Into<String>) {
525        self.selected_upstream_address = Some(address.into());
526    }
527
528    /// Increment upstream attempt counter.
529    #[inline]
530    pub fn inc_upstream_attempts(&mut self) {
531        self.upstream_attempts += 1;
532    }
533
534    /// Set response bytes.
535    #[inline]
536    pub fn set_response_bytes(&mut self, bytes: u64) {
537        self.response_bytes = bytes;
538    }
539
540    // === Fallback accessors ===
541
542    /// Get the current fallback attempt number (0 = primary).
543    #[inline]
544    pub fn fallback_attempt(&self) -> u32 {
545        self.fallback_attempt
546    }
547
548    /// Get the list of upstreams that have been tried.
549    #[inline]
550    pub fn tried_upstreams(&self) -> &[String] {
551        &self.tried_upstreams
552    }
553
554    /// Get the fallback reason, if fallback was triggered.
555    #[inline]
556    pub fn fallback_reason(&self) -> Option<&FallbackReason> {
557        self.fallback_reason.as_ref()
558    }
559
560    /// Get the original upstream ID (before fallback).
561    #[inline]
562    pub fn original_upstream(&self) -> Option<&str> {
563        self.original_upstream.as_deref()
564    }
565
566    /// Get the model mapping that was applied: (original, mapped).
567    #[inline]
568    pub fn model_mapping_applied(&self) -> Option<&(String, String)> {
569        self.model_mapping_applied.as_ref()
570    }
571
572    /// Check if fallback was used for this request.
573    #[inline]
574    pub fn used_fallback(&self) -> bool {
575        self.fallback_attempt > 0
576    }
577
578    /// Record that a fallback attempt is being made.
579    #[inline]
580    pub fn record_fallback(&mut self, reason: FallbackReason, new_upstream: &str) {
581        if self.fallback_attempt == 0 {
582            // First fallback - save original upstream
583            self.original_upstream = self.upstream.clone();
584        }
585        self.fallback_attempt += 1;
586        self.fallback_reason = Some(reason);
587        if let Some(current) = &self.upstream {
588            self.tried_upstreams.push(current.clone());
589        }
590        self.upstream = Some(new_upstream.to_string());
591    }
592
593    /// Record model mapping applied during fallback.
594    #[inline]
595    pub fn record_model_mapping(&mut self, original: String, mapped: String) {
596        self.model_mapping_applied = Some((original, mapped));
597    }
598
599    // === Model Routing accessors ===
600
601    /// Check if model-based routing was used to select the upstream.
602    #[inline]
603    pub fn used_model_routing(&self) -> bool {
604        self.model_routing_used
605    }
606
607    /// Get the provider override from model-based routing (if any).
608    #[inline]
609    pub fn inference_provider_override(&self) -> Option<sentinel_config::InferenceProvider> {
610        self.inference_provider_override
611    }
612
613    /// Record model-based routing result.
614    ///
615    /// Called when model-based routing selects an upstream based on the model name.
616    #[inline]
617    pub fn record_model_routing(
618        &mut self,
619        upstream: &str,
620        model: Option<String>,
621        provider_override: Option<sentinel_config::InferenceProvider>,
622    ) {
623        self.upstream = Some(upstream.to_string());
624        self.model_routing_used = true;
625        if model.is_some() {
626            self.inference_model = model;
627        }
628        self.inference_provider_override = provider_override;
629    }
630}
631
632impl Default for RequestContext {
633    fn default() -> Self {
634        Self::new()
635    }
636}
637
638// ============================================================================
639// Tests
640// ============================================================================
641
#[cfg(test)]
mod tests {
    use super::*;

    /// The header info struct stores exactly the values it was built with.
    #[test]
    fn test_rate_limit_header_info() {
        let RateLimitHeaderInfo {
            limit,
            remaining,
            reset_at,
        } = RateLimitHeaderInfo {
            limit: 100,
            remaining: 42,
            reset_at: 1704067200,
        };

        assert_eq!(limit, 100);
        assert_eq!(remaining, 42);
        assert_eq!(reset_at, 1704067200);
    }

    /// A fresh context has no trace ID, rate limit info, route, or config.
    #[test]
    fn test_request_context_default() {
        let fresh = RequestContext::new();

        assert!(fresh.trace_id.is_empty());
        assert!(fresh.rate_limit_info.is_none());
        assert!(fresh.route_id.is_none());
        assert!(fresh.config.is_none());
    }

    /// Rate limit info starts absent and can be read back once assigned.
    #[test]
    fn test_request_context_rate_limit_info() {
        let mut ctx = RequestContext::new();

        // Nothing is recorded before the rate limiter has run.
        assert!(ctx.rate_limit_info.is_none());

        let assigned = RateLimitHeaderInfo {
            limit: 50,
            remaining: 25,
            reset_at: 1704067300,
        };
        ctx.rate_limit_info = Some(assigned);

        assert!(ctx.rate_limit_info.is_some());
        let stored = ctx
            .rate_limit_info
            .as_ref()
            .expect("rate limit info was assigned above");
        assert_eq!(stored.limit, 50);
        assert_eq!(stored.remaining, 25);
        assert_eq!(stored.reset_at, 1704067300);
    }

    /// Elapsed time measured right after construction is well under a second.
    #[test]
    fn test_request_context_elapsed() {
        let ctx = RequestContext::new();

        let one_second = std::time::Duration::from_secs(1);
        assert!(ctx.elapsed() < one_second);
    }

    /// Each setter is observable through its matching getter.
    #[test]
    fn test_request_context_setters() {
        let mut ctx = RequestContext::new();

        ctx.set_trace_id("trace-123");
        assert_eq!(ctx.trace_id(), "trace-123");
        assert_eq!(ctx.correlation_id(), "trace-123");

        ctx.set_route_id("my-route");
        assert_eq!(ctx.route_id(), Some("my-route"));

        ctx.set_upstream("backend-pool");
        assert_eq!(ctx.upstream(), Some("backend-pool"));

        for _ in 0..2 {
            ctx.inc_upstream_attempts();
        }
        assert_eq!(ctx.upstream_attempts(), 2);

        ctx.set_response_bytes(1024);
        assert_eq!(ctx.response_bytes(), 1024);
    }

    /// Two consecutive fallbacks keep the primary upstream as the original,
    /// accumulate the tried list in order, and expose the latest reason.
    #[test]
    fn test_fallback_tracking() {
        let mut ctx = RequestContext::new();

        // Before any fallback, all fallback state is empty.
        assert_eq!(ctx.fallback_attempt(), 0);
        assert!(!ctx.used_fallback());
        assert!(ctx.tried_upstreams().is_empty());
        assert!(ctx.fallback_reason().is_none());
        assert!(ctx.original_upstream().is_none());

        // Primary upstream.
        ctx.set_upstream("openai-primary");

        // First fallback.
        ctx.record_fallback(FallbackReason::HealthCheckFailed, "anthropic-fallback");

        assert_eq!(ctx.fallback_attempt(), 1);
        assert!(ctx.used_fallback());
        let tried_after_first: Vec<&str> =
            ctx.tried_upstreams().iter().map(String::as_str).collect();
        assert_eq!(tried_after_first, vec!["openai-primary"]);
        assert!(matches!(
            ctx.fallback_reason(),
            Some(FallbackReason::HealthCheckFailed)
        ));
        assert_eq!(ctx.original_upstream(), Some("openai-primary"));
        assert_eq!(ctx.upstream(), Some("anthropic-fallback"));

        // Second fallback.
        ctx.record_fallback(FallbackReason::ErrorCode(503), "local-gpu");

        assert_eq!(ctx.fallback_attempt(), 2);
        let tried_after_second: Vec<&str> =
            ctx.tried_upstreams().iter().map(String::as_str).collect();
        assert_eq!(
            tried_after_second,
            vec!["openai-primary", "anthropic-fallback"]
        );
        assert!(matches!(
            ctx.fallback_reason(),
            Some(FallbackReason::ErrorCode(503))
        ));
        // The original upstream is still the very first one.
        assert_eq!(ctx.original_upstream(), Some("openai-primary"));
        assert_eq!(ctx.upstream(), Some("local-gpu"));
    }

    /// A recorded model mapping is returned as the `(original, mapped)` pair.
    #[test]
    fn test_model_mapping_tracking() {
        let mut ctx = RequestContext::new();

        assert!(ctx.model_mapping_applied().is_none());

        ctx.record_model_mapping("gpt-4".to_string(), "claude-3-opus".to_string());

        let (original, mapped) = ctx
            .model_mapping_applied()
            .expect("model mapping was recorded above");
        assert_eq!(original.as_str(), "gpt-4");
        assert_eq!(mapped.as_str(), "claude-3-opus");
    }

    /// Every variant renders to its expected label.
    #[test]
    fn test_fallback_reason_display() {
        let cases = vec![
            (FallbackReason::HealthCheckFailed, "health_check_failed"),
            (FallbackReason::BudgetExhausted, "budget_exhausted"),
            (
                FallbackReason::LatencyThreshold {
                    observed_ms: 5500,
                    threshold_ms: 5000,
                },
                "latency_threshold_5500ms_exceeded_5000ms",
            ),
            (FallbackReason::ErrorCode(502), "error_code_502"),
            (
                FallbackReason::ConnectionError("timeout".to_string()),
                "connection_error_timeout",
            ),
        ];

        for (reason, expected) in cases {
            assert_eq!(reason.to_string(), expected);
        }
    }
}