Skip to main content

zentinel_proxy/
http_helpers.rs

1//! HTTP request and response helpers for Zentinel proxy
2//!
3//! This module provides utilities for:
4//! - Extracting request information from Pingora sessions
5//! - Writing HTTP responses to Pingora sessions
6//! - Trace ID extraction from headers
7//!
8//! These helpers reduce boilerplate in the main proxy logic and ensure
9//! consistent handling of HTTP operations.
10
11use bytes::Bytes;
12use http::Response;
13use http_body_util::{BodyExt, Full};
14use pingora::http::{RequestHeader, ResponseHeader};
15use pingora::prelude::*;
16use pingora::proxy::Session;
17use std::collections::HashMap;
18
19use crate::routing::RequestInfo;
20use crate::trace_id::{generate_for_format, TraceIdFormat};
21
22// ============================================================================
23// Request Helpers
24// ============================================================================
25
/// Owned snapshot of a request, intended for use off the hot path.
///
/// Every field is an owned value, so an instance can be kept around when
/// threading the lifetime of the borrowed `RequestInfo<'a>` is impractical
/// (e.g., storing beyond request scope).
///
/// `Default` produces an all-empty value, and `PartialEq`/`Eq` allow two
/// snapshots to be compared directly (handy in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OwnedRequestInfo {
    /// HTTP method as text (e.g. `"GET"`).
    pub method: String,
    /// Request path.
    pub path: String,
    /// Request host; empty when the request carried none.
    pub host: String,
    /// Request headers as name/value string pairs.
    pub headers: HashMap<String, String>,
    /// Query-string parameters as key/value string pairs.
    pub query_params: HashMap<String, String>,
}
38
39/// Extract the request host from a Pingora `RequestHeader`.
40///
41/// Resolves the host using a protocol-aware fallback chain so route matching
42/// works consistently for HTTP/1.1, HTTP/2, and absolute-URI requests:
43///
44/// 1. `uri.host()` — populated by Pingora from the HTTP/2 `:authority`
45///    pseudo-header and from absolute-form HTTP/1.1 request URIs.
46/// 2. The `Host` header — used by HTTP/1.1 requests with a relative URI.
47///
48/// Returns `""` if neither source provides a host. Any port suffix is left
49/// intact; downstream matchers (e.g. `HostMatcher::matches`) are responsible
50/// for stripping it per Gateway API semantics.
51pub fn extract_request_host(req_header: &RequestHeader) -> &str {
52    if let Some(host) = req_header.uri.host() {
53        return host;
54    }
55    req_header
56        .headers
57        .get("host")
58        .and_then(|h| h.to_str().ok())
59        .unwrap_or("")
60}
61
62/// Extract request info from a Pingora session
63///
64/// Builds an `OwnedRequestInfo` struct from the session's request headers.
65/// This function allocates all fields.
66///
67/// For the hot path, use `RequestInfo::new()` with
68/// `with_headers()`/`with_query_params()` only when needed.
69///
70/// # Example
71///
72/// ```ignore
73/// let request_info = extract_request_info(session);
74/// ```
75pub fn extract_request_info(session: &Session) -> OwnedRequestInfo {
76    let req_header = session.req_header();
77
78    let headers = RequestInfo::build_headers(req_header.headers.iter());
79    let host = extract_request_host(req_header).to_string();
80    let path = req_header.uri.path().to_string();
81    let method = req_header.method.as_str().to_string();
82
83    OwnedRequestInfo {
84        method,
85        path: path.clone(),
86        host,
87        headers,
88        query_params: RequestInfo::parse_query_params(&path),
89    }
90}
91
92/// Extract or generate a trace ID from request headers
93///
94/// Looks for existing trace ID headers in order of preference:
95/// 1. `X-Trace-Id`
96/// 2. `X-Correlation-Id`
97/// 3. `X-Request-Id`
98///
99/// If none are found, generates a new TinyFlake trace ID (11 chars).
100/// See [`crate::trace_id`] module for TinyFlake format details.
101///
102/// # Example
103///
104/// ```ignore
105/// let trace_id = get_or_create_trace_id(session, TraceIdFormat::TinyFlake);
106/// tracing::info!(trace_id = %trace_id, "Processing request");
107/// ```
108pub fn get_or_create_trace_id(session: &Session, format: TraceIdFormat) -> String {
109    let req_header = session.req_header();
110
111    // Check for existing trace ID headers (in order of preference)
112    const TRACE_HEADERS: [&str; 3] = ["x-trace-id", "x-correlation-id", "x-request-id"];
113
114    for header_name in &TRACE_HEADERS {
115        if let Some(value) = req_header.headers.get(*header_name) {
116            if let Ok(id) = value.to_str() {
117                if !id.is_empty() {
118                    return id.to_string();
119                }
120            }
121        }
122    }
123
124    // Generate new trace ID using configured format
125    generate_for_format(format)
126}
127
128/// Extract or generate a trace ID (convenience function using TinyFlake default)
129///
130/// This is a convenience wrapper around [`get_or_create_trace_id`] that uses
131/// the default TinyFlake format.
132#[inline]
133pub fn get_or_create_trace_id_default(session: &Session) -> String {
134    get_or_create_trace_id(session, TraceIdFormat::default())
135}
136
137// ============================================================================
138// Response Helpers
139// ============================================================================
140
141/// Write an HTTP response to a Pingora session
142///
143/// Handles the conversion from `http::Response<Full<Bytes>>` to Pingora's
144/// format and writes it to the session.
145///
146/// # Arguments
147///
148/// * `session` - The Pingora session to write to
149/// * `response` - The HTTP response to write
150/// * `keepalive_secs` - Keepalive timeout in seconds (None = disable keepalive)
151///
152/// # Returns
153///
154/// Returns `Ok(())` on success or an error if writing fails.
155///
156/// # Example
157///
158/// ```ignore
159/// let response = Response::builder()
160///     .status(200)
161///     .body(Full::new(Bytes::from("OK")))?;
162/// write_response(session, response, Some(60)).await?;
163/// ```
164pub async fn write_response(
165    session: &mut Session,
166    response: Response<Full<Bytes>>,
167    keepalive_secs: Option<u64>,
168) -> Result<(), Box<Error>> {
169    let status = response.status().as_u16();
170
171    // Collect headers to owned strings to avoid lifetime issues
172    let headers_owned: Vec<(String, String)> = response
173        .headers()
174        .iter()
175        .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
176        .collect();
177
178    // Extract body bytes
179    let full_body = response.into_body();
180    let body_bytes: Bytes = BodyExt::collect(full_body)
181        .await
182        .map(|collected| collected.to_bytes())
183        .unwrap_or_default();
184
185    // Build Pingora response header
186    let mut resp_header = ResponseHeader::build(status, None)?;
187    for (key, value) in headers_owned {
188        resp_header.insert_header(key, &value)?;
189    }
190
191    // Write response to session
192    session.set_keepalive(keepalive_secs);
193    session
194        .write_response_header(Box::new(resp_header), false)
195        .await?;
196    session.write_response_body(Some(body_bytes), true).await?;
197
198    Ok(())
199}
200
201/// Write an error response to a Pingora session
202///
203/// Convenience wrapper for error responses with status code, body, and content type.
204///
205/// # Arguments
206///
207/// * `session` - The Pingora session to write to
208/// * `status` - HTTP status code
209/// * `body` - Response body as string
210/// * `content_type` - Content-Type header value
211pub async fn write_error(
212    session: &mut Session,
213    status: u16,
214    body: &str,
215    content_type: &str,
216) -> Result<(), Box<Error>> {
217    let mut resp_header = ResponseHeader::build(status, None)?;
218    resp_header.insert_header("Content-Type", content_type)?;
219    resp_header.insert_header("Content-Length", body.len().to_string())?;
220
221    session.set_keepalive(None);
222    session
223        .write_response_header(Box::new(resp_header), false)
224        .await?;
225    session
226        .write_response_body(Some(Bytes::copy_from_slice(body.as_bytes())), true)
227        .await?;
228
229    Ok(())
230}
231
232/// Write a plain text error response
233///
234/// Shorthand for `write_error` with `text/plain; charset=utf-8` content type.
235pub async fn write_text_error(
236    session: &mut Session,
237    status: u16,
238    message: &str,
239) -> Result<(), Box<Error>> {
240    write_error(session, status, message, "text/plain; charset=utf-8").await
241}
242
243/// Write a JSON error response
244///
245/// Creates a JSON object with `error` and optional `message` fields.
246///
247/// # Example
248///
249/// ```ignore
250/// // Produces: {"error":"not_found","message":"Resource does not exist"}
251/// write_json_error(session, 404, "not_found", Some("Resource does not exist")).await?;
252/// ```
253pub async fn write_json_error(
254    session: &mut Session,
255    status: u16,
256    error: &str,
257    message: Option<&str>,
258) -> Result<(), Box<Error>> {
259    let body = match message {
260        Some(msg) => format!(r#"{{"error":"{}","message":"{}"}}"#, error, msg),
261        None => format!(r#"{{"error":"{}"}}"#, error),
262    };
263    write_error(session, status, &body, "application/json").await
264}
265
266/// Write a rate limit error response with standard rate limit headers
267///
268/// Includes the following headers:
269/// - `X-RateLimit-Limit`: Maximum requests per window
270/// - `X-RateLimit-Remaining`: Remaining requests in current window
271/// - `X-RateLimit-Reset`: Unix timestamp when the window resets
272/// - `Retry-After`: Seconds until the client should retry (for 429 responses)
273///
274/// # Arguments
275///
276/// * `session` - The Pingora session to write to
277/// * `status` - HTTP status code (typically 429)
278/// * `body` - Response body as string
279/// * `limit` - Maximum requests allowed per window
280/// * `remaining` - Remaining requests in current window
281/// * `reset_at` - Unix timestamp when the window resets
282/// * `retry_after` - Seconds until client should retry
283pub async fn write_rate_limit_error(
284    session: &mut Session,
285    status: u16,
286    body: &str,
287    limit: u32,
288    remaining: u32,
289    reset_at: u64,
290    retry_after: u64,
291) -> Result<(), Box<Error>> {
292    let mut resp_header = ResponseHeader::build(status, None)?;
293    resp_header.insert_header("Content-Type", "text/plain; charset=utf-8")?;
294    resp_header.insert_header("Content-Length", body.len().to_string())?;
295
296    // Add standard rate limit headers
297    resp_header.insert_header("X-RateLimit-Limit", limit.to_string())?;
298    resp_header.insert_header("X-RateLimit-Remaining", remaining.to_string())?;
299    resp_header.insert_header("X-RateLimit-Reset", reset_at.to_string())?;
300
301    // Add Retry-After header (seconds until reset)
302    if retry_after > 0 {
303        resp_header.insert_header("Retry-After", retry_after.to_string())?;
304    }
305
306    session.set_keepalive(None);
307    session
308        .write_response_header(Box::new(resp_header), false)
309        .await?;
310    session
311        .write_response_body(Some(Bytes::copy_from_slice(body.as_bytes())), true)
312        .await?;
313
314    Ok(())
315}
316
317// ============================================================================
318// Tests
319// ============================================================================
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a GET request header whose URI is `uri`, optionally adding a
    /// `Host` header with the given value.
    fn req(uri: &str, host_header: Option<&str>) -> RequestHeader {
        let mut header = RequestHeader::build("GET", b"/", None).unwrap();
        header.set_uri(uri.parse().unwrap());
        if let Some(value) = host_header {
            header.insert_header("host", value).unwrap();
        }
        header
    }

    #[test]
    fn extract_host_prefers_uri_host_for_absolute_uri() {
        // Absolute-form HTTP/1.1 request: the URI's own host must win over
        // a conflicting Host header.
        let header = req("http://example.com/path", Some("other.example.org"));
        let host = extract_request_host(&header);
        assert_eq!(host, "example.com");
    }

    #[test]
    fn extract_host_falls_back_to_header_for_relative_uri() {
        // Relative-form HTTP/1.1 request: the URI has no host, so the Host
        // header (port included) is what comes back.
        let header = req(
            "/_matrix/federation/v1/send/123",
            Some("im.example.com:443"),
        );
        let host = extract_request_host(&header);
        assert_eq!(host, "im.example.com:443");
    }

    #[test]
    fn extract_host_returns_empty_when_no_host_anywhere() {
        let header = req("/path", None);
        let host = extract_request_host(&header);
        assert_eq!(host, "");
    }

    #[test]
    fn extract_host_uses_uri_when_header_missing() {
        // Mirrors the HTTP/2 situation, where the authority lives in the URI
        // and no Host header is present.
        let header = req("http://api.example.com/v1", None);
        let host = extract_request_host(&header);
        assert_eq!(host, "api.example.com");
    }

    // Trace ID generation is tested in the crate::trace_id module.
    // Exercising get_or_create_trace_id needs a mocked Pingora session;
    // see crates/proxy/tests/ for integration test examples.
}