Skip to main content

structured_proxy/transcode/
metadata.rs

1//! HTTP → gRPC metadata, trace-context, and deadline propagation.
2//!
3//! Converts relevant HTTP headers into gRPC `MetadataMap` entries for upstream
4//! calls (forwarded headers are configurable via YAML), propagates W3C
5//! trace-context across the boundary, and carries a client deadline through as
6//! the upstream call timeout.
7
8use std::time::Duration;
9
10use axum::http::HeaderMap;
11use tonic::metadata::MetadataMap;
12
13/// Extract HTTP headers into a gRPC `MetadataMap`.
14///
15/// Forwards the headers listed in `forwarded_headers`, then always propagates
16/// W3C trace-context (forwarding an incoming `traceparent` or synthesizing one
17/// so the upstream joins a single trace across the REST↔gRPC boundary).
18pub fn http_headers_to_grpc_metadata(
19    headers: &HeaderMap,
20    forwarded_headers: &[String],
21) -> MetadataMap {
22    let mut metadata = MetadataMap::new();
23
24    for header_name in forwarded_headers {
25        if let Some(value) = headers.get(header_name.as_str()) {
26            insert_ascii(&mut metadata, header_name, value.as_bytes());
27        }
28    }
29
30    inject_trace_context(&mut metadata, headers);
31
32    metadata
33}
34
35/// Insert an ASCII metadata entry, silently skipping non-ASCII keys/values.
36fn insert_ascii(metadata: &mut MetadataMap, key: &str, value: &[u8]) {
37    if let (Ok(k), Ok(v)) = (
38        key.parse::<tonic::metadata::MetadataKey<tonic::metadata::Ascii>>(),
39        tonic::metadata::AsciiMetadataValue::try_from(value),
40    ) {
41        metadata.insert(k, v);
42    }
43}
44
45/// Propagate W3C trace-context into gRPC metadata.
46///
47/// Forwards an incoming `traceparent` (and `tracestate`) only when it is
48/// well-formed per W3C §3.2.2; otherwise (missing or malformed) synthesizes a
49/// fresh one so the upstream always receives a single valid, joinable trace.
50fn inject_trace_context(metadata: &mut MetadataMap, headers: &HeaderMap) {
51    if let Some(tp) = headers.get("traceparent").and_then(|v| v.to_str().ok()) {
52        if is_valid_traceparent(tp) {
53            insert_ascii(metadata, "traceparent", tp.as_bytes());
54            // tracestate only travels with the trace it annotates.
55            if let Some(ts) = headers.get("tracestate") {
56                insert_ascii(metadata, "tracestate", ts.as_bytes());
57            }
58            return;
59        }
60    }
61    if let Some(tp) = new_traceparent() {
62        insert_ascii(metadata, "traceparent", tp.as_bytes());
63    }
64}
65
/// Validate a W3C `traceparent` header value.
///
/// The value must be `<version>-<trace-id>-<parent-id>-<flags>`, where each
/// field is *lowercase* hex (W3C `HEXDIGLC`) of length 2, 32, 16 and 2
/// respectively. It is additionally rejected when:
/// - the version is the reserved `ff`. Per W3C §3.2.1 any other 2-hex version
///   is accepted, so future versions keep propagating;
/// - the trace-id or parent-id is all zeros (forbidden by W3C §3.2.2);
/// - the version is `00` and extra `-`-delimited fields follow the flags
///   (only future versions may append fields);
/// - any field contains uppercase hex, which the spec does not allow.
fn is_valid_traceparent(tp: &str) -> bool {
    // Lowercase only. A strict upstream would reject uppercase IDs and start a
    // new, disconnected trace, defeating the point of forwarding.
    let is_lower_hex = |s: &str, len: usize| {
        s.len() == len && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    };
    let is_all_zeros = |s: &str| s.bytes().all(|b| b == b'0');

    let fields: Vec<&str> = tp.split('-').collect();
    let (version, trace_id, parent_id, flags, has_extra) = match fields.as_slice() {
        [v, t, p, f, rest @ ..] => (*v, *t, *p, *f, !rest.is_empty()),
        _ => return false,
    };

    is_lower_hex(version, 2)
        && version != "ff"
        // Version 00 has exactly four fields; later versions may append more.
        && !(version == "00" && has_extra)
        && is_lower_hex(trace_id, 32)
        && is_lower_hex(parent_id, 16)
        && is_lower_hex(flags, 2)
        && !is_all_zeros(trace_id)
        && !is_all_zeros(parent_id)
}
90
91/// Build a fresh W3C `traceparent`: `00-<16-byte trace-id>-<8-byte span-id>-01`
92/// (sampled). Returns `None` only if the system RNG is unavailable.
93fn new_traceparent() -> Option<String> {
94    let mut buf = [0u8; 24];
95    getrandom::fill(&mut buf).ok()?;
96    let trace_id = hex(&buf[..16]);
97    let span_id = hex(&buf[16..]);
98    Some(format!("00-{trace_id}-{span_id}-01"))
99}
100
/// Encode `bytes` as lowercase hex: two digits per byte, no separators.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
110
111/// Apply a client-supplied deadline to the upstream gRPC call.
112///
113/// Reads the gRPC-standard `grpc-timeout` header (`<int><unit>`, unit one of
114/// `H`/`M`/`S`/`m`/`u`/`n`) and sets it as the request timeout. Absent or
115/// malformed values leave the channel default in place. Returns the deadline
116/// that was applied, if any.
117pub fn apply_request_deadline<T>(
118    request: &mut tonic::Request<T>,
119    headers: &HeaderMap,
120) -> Option<Duration> {
121    let timeout = headers
122        .get("grpc-timeout")
123        .and_then(|v| v.to_str().ok())
124        .and_then(parse_grpc_timeout)?;
125    request.set_timeout(timeout);
126    Some(timeout)
127}
128
/// Parse a gRPC `grpc-timeout` value (`<int><unit>`) into a [`Duration`].
///
/// Units: `H` hours, `M` minutes, `S` seconds, `m` milliseconds, `u`
/// microseconds, `n` nanoseconds. Per the gRPC wire spec the value is 1 to 8
/// ASCII digits with no sign. Surrounding whitespace is tolerated.
///
/// Returns `None` in these cases:
/// - the value is malformed: no unit, unknown unit, empty or over-long digit
///   run, or any non-digit character such as a `+` sign;
/// - the duration is zero, which would expire the call immediately, so the
///   channel default is used instead.
///
/// Never panics, even on non-ASCII input.
fn parse_grpc_timeout(value: &str) -> Option<Duration> {
    let value = value.trim();
    // Split off the last *character* as the unit. Slicing at `len - 1` would
    // panic if that character were multi-byte UTF-8.
    let unit = value.chars().next_back()?;
    let digits = &value[..value.len() - unit.len_utf8()];
    // Check the digits explicitly: `u64::from_str` alone would accept "+5".
    if digits.is_empty() || digits.len() > 8 || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let n: u64 = digits.parse().ok()?;
    // With at most 8 digits, n <= 99_999_999, so n * 3600 < 4e11 << u64::MAX:
    // the multiplications cannot overflow.
    let dur = match unit {
        'H' => Duration::from_secs(n * 3600),
        'M' => Duration::from_secs(n * 60),
        'S' => Duration::from_secs(n),
        'm' => Duration::from_millis(n),
        'u' => Duration::from_micros(n),
        'n' => Duration::from_nanos(n),
        _ => return None,
    };
    if dur.is_zero() {
        return None;
    }
    Some(dur)
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// The forwarding list most tests use, mirroring a typical YAML config.
    fn default_headers() -> Vec<String> {
        [
            "authorization",
            "dpop",
            "x-request-id",
            "x-forwarded-for",
            "x-forwarded-proto",
            "x-real-ip",
            "accept-language",
            "user-agent",
            "idempotency-key",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect()
    }

    /// Build a `HeaderMap` from static `(name, value)` pairs.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Read an ASCII metadata entry as `&str`, panicking if it is missing.
    fn text<'a>(meta: &'a MetadataMap, key: &str) -> &'a str {
        meta.get(key).unwrap().to_str().unwrap()
    }

    #[test]
    fn test_authorization_forwarded() {
        let headers = headers_from(&[("authorization", "Bearer tok123")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(text(&meta, "authorization"), "Bearer tok123");
    }

    #[test]
    fn test_multiple_headers_forwarded() {
        let expected = [
            ("authorization", "Bearer tok"),
            ("x-request-id", "req-42"),
            ("accept-language", "en-US"),
        ];
        let headers = headers_from(&expected);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        for (key, value) in expected {
            assert_eq!(text(&meta, key), value);
        }
    }

    #[test]
    fn test_unknown_headers_not_forwarded() {
        let headers = headers_from(&[("x-custom-header", "value")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert!(meta.get("x-custom-header").is_none());
    }

    #[test]
    fn test_custom_forwarded_headers() {
        let headers = headers_from(&[("x-custom-header", "value")]);
        let meta = http_headers_to_grpc_metadata(&headers, &["x-custom-header".to_string()]);
        assert_eq!(text(&meta, "x-custom-header"), "value");
    }

    #[test]
    fn test_empty_headers_still_inject_traceparent() {
        // Nothing to forward, yet a trace-context is still synthesized so the
        // upstream joins a single trace.
        let meta = http_headers_to_grpc_metadata(&HeaderMap::new(), &default_headers());
        let tp = text(&meta, "traceparent");
        assert!(is_valid_traceparent(tp), "bad traceparent: {tp}");
        // ...and no forwarded header appears out of nowhere.
        assert!(meta.get("authorization").is_none());
    }

    #[test]
    fn traceparent_is_forwarded_when_present() {
        let incoming = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let headers = headers_from(&[("traceparent", incoming), ("tracestate", "vendor=value")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(text(&meta, "traceparent"), incoming);
        assert_eq!(text(&meta, "tracestate"), "vendor=value");
    }

    #[test]
    fn synthesized_traceparent_is_unique_per_call() {
        let headers = HeaderMap::new();
        let first = http_headers_to_grpc_metadata(&headers, &[]);
        let second = http_headers_to_grpc_metadata(&headers, &[]);
        assert_ne!(text(&first, "traceparent"), text(&second, "traceparent"));
    }

    #[test]
    fn grpc_timeout_parses_each_unit() {
        let cases = [
            ("5S", Duration::from_secs(5)),
            ("100m", Duration::from_millis(100)),
            ("2M", Duration::from_secs(120)),
            ("1H", Duration::from_secs(3600)),
            ("250u", Duration::from_micros(250)),
            ("9n", Duration::from_nanos(9)),
        ];
        for (raw, expected) in cases {
            assert_eq!(parse_grpc_timeout(raw), Some(expected), "input {raw:?}");
        }
    }

    #[test]
    fn grpc_timeout_rejects_malformed() {
        for raw in ["", "S", "10X", "abcS"] {
            assert_eq!(parse_grpc_timeout(raw), None, "input {raw:?}");
        }
    }

    #[test]
    fn grpc_timeout_rejects_zero_duration() {
        // A zero deadline would expire the upstream call immediately, failing
        // every such request with DEADLINE_EXCEEDED before it does any work.
        for raw in ["0S", "0m", "0n"] {
            assert_eq!(parse_grpc_timeout(raw), None, "input {raw:?}");
        }
    }

    #[test]
    fn grpc_timeout_enforces_8_digit_limit() {
        // The gRPC wire spec caps TimeoutValue at 8 digits.
        assert_eq!(
            parse_grpc_timeout("99999999S"),
            Some(Duration::from_secs(99_999_999))
        );
        assert_eq!(parse_grpc_timeout("999999999S"), None); // 9 digits
    }

    #[test]
    fn versioned_traceparent_is_forwarded() {
        // W3C 3.2.1 requires accepting future versions (anything but ff): a
        // valid version-01 header must be propagated, not dropped and
        // resynthesized.
        let incoming = "01-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let headers = headers_from(&[("traceparent", incoming)]);
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        assert_eq!(text(&meta, "traceparent"), incoming);
    }

    #[test]
    fn ff_version_traceparent_is_rejected() {
        // The reserved "ff" version is invalid per W3C and must be replaced.
        let invalid = "ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let headers = headers_from(&[("traceparent", invalid)]);
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        let tp = text(&meta, "traceparent");
        assert_ne!(tp, invalid);
        assert!(is_valid_traceparent(tp));
    }

    #[test]
    fn malformed_or_zero_traceparent_is_not_forwarded() {
        // An all-zeros traceparent is invalid per W3C §3.2.2 and must not be
        // propagated; a fresh one is synthesized instead.
        let zeros = "00-00000000000000000000000000000000-0000000000000000-01";
        let headers = headers_from(&[("traceparent", zeros)]);
        let meta = http_headers_to_grpc_metadata(&headers, &[]);
        let tp = text(&meta, "traceparent");
        assert_ne!(tp, zeros);
        assert!(
            is_valid_traceparent(tp),
            "synthesized traceparent invalid: {tp}"
        );
    }

    #[test]
    fn apply_request_deadline_sets_timeout_from_header() {
        let headers = headers_from(&[("grpc-timeout", "3S")]);
        let mut req = tonic::Request::new(());
        assert_eq!(
            apply_request_deadline(&mut req, &headers),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn apply_request_deadline_noop_without_header() {
        let mut req = tonic::Request::new(());
        assert_eq!(apply_request_deadline(&mut req, &HeaderMap::new()), None);
    }

    #[test]
    fn test_dpop_forwarded() {
        let headers = headers_from(&[("dpop", "eyJ0eXAiOiJkcG9wK2p3dCJ9")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert!(meta.get("dpop").is_some());
    }
}