structured_proxy/transcode/
metadata.rs1use std::time::Duration;
9
10use axum::http::HeaderMap;
11use tonic::metadata::MetadataMap;
12
13pub fn http_headers_to_grpc_metadata(
19 headers: &HeaderMap,
20 forwarded_headers: &[String],
21) -> MetadataMap {
22 let mut metadata = MetadataMap::new();
23
24 for header_name in forwarded_headers {
25 if let Some(value) = headers.get(header_name.as_str()) {
26 insert_ascii(&mut metadata, header_name, value.as_bytes());
27 }
28 }
29
30 inject_trace_context(&mut metadata, headers);
31
32 metadata
33}
34
35fn insert_ascii(metadata: &mut MetadataMap, key: &str, value: &[u8]) {
37 if let (Ok(k), Ok(v)) = (
38 key.parse::<tonic::metadata::MetadataKey<tonic::metadata::Ascii>>(),
39 tonic::metadata::AsciiMetadataValue::try_from(value),
40 ) {
41 metadata.insert(k, v);
42 }
43}
44
45fn inject_trace_context(metadata: &mut MetadataMap, headers: &HeaderMap) {
51 if let Some(tp) = headers.get("traceparent").and_then(|v| v.to_str().ok()) {
52 if is_valid_traceparent(tp) {
53 insert_ascii(metadata, "traceparent", tp.as_bytes());
54 if let Some(ts) = headers.get("tracestate") {
56 insert_ascii(metadata, "tracestate", ts.as_bytes());
57 }
58 return;
59 }
60 }
61 if let Some(tp) = new_traceparent() {
62 insert_ascii(metadata, "traceparent", tp.as_bytes());
63 }
64}
65
/// Reports whether `tp` is a well-formed W3C `traceparent` header value.
///
/// Format: `version-traceid-parentid-flags`.
/// - `version` is 2 lowercase hex digits and must not be `ff`.
/// - `traceid` is 32 lowercase hex digits, not all zero.
/// - `parentid` is 16 lowercase hex digits, not all zero.
/// - `flags` is 2 lowercase hex digits.
///
/// Version `00` must have exactly these four fields. Higher versions may append
/// further `-`-separated fields, which are ignored here. Uppercase hex is
/// rejected: the spec's grammar is HEXDIGLC and requires vendors to ignore
/// non-conforming values.
fn is_valid_traceparent(tp: &str) -> bool {
    fn is_lower_hex(s: &str, len: usize) -> bool {
        s.len() == len && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    }
    fn not_all_zero(s: &str) -> bool {
        s.bytes().any(|b| b != b'0')
    }

    let parts: Vec<&str> = tp.split('-').collect();
    let (version, trace_id, parent_id, flags, extra) = match parts.as_slice() {
        [version, trace_id, parent_id, flags, extra @ ..] => {
            (*version, *trace_id, *parent_id, *flags, extra)
        }
        _ => return false,
    };
    // Version 00 defines exactly four fields; only future versions may extend it.
    if version == "00" && !extra.is_empty() {
        return false;
    }
    is_lower_hex(version, 2)
        && version != "ff"
        && is_lower_hex(trace_id, 32)
        && is_lower_hex(parent_id, 16)
        && is_lower_hex(flags, 2)
        && not_all_zero(trace_id)
        && not_all_zero(parent_id)
}
91fn new_traceparent() -> Option<String> {
94 let mut buf = [0u8; 24];
95 getrandom::fill(&mut buf).ok()?;
96 let trace_id = hex(&buf[..16]);
97 let span_id = hex(&buf[16..]);
98 Some(format!("00-{trace_id}-{span_id}-01"))
99}
100
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
111pub fn apply_request_deadline<T>(
118 request: &mut tonic::Request<T>,
119 headers: &HeaderMap,
120) -> Option<Duration> {
121 let timeout = headers
122 .get("grpc-timeout")
123 .and_then(|v| v.to_str().ok())
124 .and_then(parse_grpc_timeout)?;
125 request.set_timeout(timeout);
126 Some(timeout)
127}
128
/// Parses a gRPC `grpc-timeout` header value (`TimeoutValue TimeoutUnit`).
///
/// The value must be 1–8 ASCII digits followed by exactly one unit character:
/// `H` hours, `M` minutes, `S` seconds, `m` milliseconds, `u` microseconds,
/// or `n` nanoseconds. Surrounding whitespace is trimmed.
///
/// Returns `None` for malformed input (sign characters, non-digit or
/// non-ASCII bytes, more than 8 digits, unknown unit) and for a zero duration.
fn parse_grpc_timeout(value: &str) -> Option<Duration> {
    let value = value.trim();
    // Split on bytes: slicing the `str` before its last byte would panic if the
    // final character were multi-byte UTF-8.
    let (&unit, digits) = value.as_bytes().split_last()?;
    // Check digits explicitly: `u64::from_str` would also accept a leading `+`.
    if digits.is_empty() || digits.len() > 8 || !digits.iter().all(u8::is_ascii_digit) {
        return None;
    }
    // At most 8 decimal digits, so n <= 99_999_999 and `n * 3600` cannot overflow.
    let n = digits
        .iter()
        .fold(0u64, |acc, &d| acc * 10 + u64::from(d - b'0'));
    let dur = match unit {
        b'H' => Duration::from_secs(n * 3600),
        b'M' => Duration::from_secs(n * 60),
        b'S' => Duration::from_secs(n),
        b'm' => Duration::from_millis(n),
        b'u' => Duration::from_micros(n),
        b'n' => Duration::from_nanos(n),
        _ => return None,
    };
    if dur.is_zero() {
        None
    } else {
        Some(dur)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Header allow-list used by most tests.
    fn default_headers() -> Vec<String> {
        [
            "authorization",
            "dpop",
            "x-request-id",
            "x-forwarded-for",
            "x-forwarded-proto",
            "x-real-ip",
            "accept-language",
            "user-agent",
            "idempotency-key",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect()
    }

    /// Builds a `HeaderMap` from static `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Reads an ASCII metadata entry as `&str`; panics if it is missing.
    fn text<'a>(meta: &'a MetadataMap, key: &'static str) -> &'a str {
        meta.get(key).unwrap().to_str().unwrap()
    }

    #[test]
    fn test_authorization_forwarded() {
        let headers = header_map(&[("authorization", "Bearer tok123")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(text(&meta, "authorization"), "Bearer tok123");
    }

    #[test]
    fn test_multiple_headers_forwarded() {
        let expected = [
            ("authorization", "Bearer tok"),
            ("x-request-id", "req-42"),
            ("accept-language", "en-US"),
        ];
        let meta = http_headers_to_grpc_metadata(&header_map(&expected), &default_headers());
        for (key, value) in expected {
            assert_eq!(text(&meta, key), value);
        }
    }

    #[test]
    fn test_unknown_headers_not_forwarded() {
        let headers = header_map(&[("x-custom-header", "value")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert!(meta.get("x-custom-header").is_none());
    }

    #[test]
    fn test_custom_forwarded_headers() {
        let headers = header_map(&[("x-custom-header", "value")]);
        let allow = ["x-custom-header".to_owned()];
        let meta = http_headers_to_grpc_metadata(&headers, &allow);
        assert_eq!(text(&meta, "x-custom-header"), "value");
    }

    #[test]
    fn test_empty_headers_still_inject_traceparent() {
        let meta = http_headers_to_grpc_metadata(&HeaderMap::new(), &default_headers());
        let tp = text(&meta, "traceparent");
        assert!(is_valid_traceparent(tp), "bad traceparent: {tp}");
        assert!(meta.get("authorization").is_none());
    }

    #[test]
    fn traceparent_is_forwarded_when_present() {
        let incoming = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let headers = header_map(&[("traceparent", incoming), ("tracestate", "vendor=value")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert_eq!(text(&meta, "traceparent"), incoming);
        assert_eq!(text(&meta, "tracestate"), "vendor=value");
    }

    #[test]
    fn synthesized_traceparent_is_unique_per_call() {
        let headers = HeaderMap::new();
        let first = http_headers_to_grpc_metadata(&headers, &[]);
        let second = http_headers_to_grpc_metadata(&headers, &[]);
        assert_ne!(text(&first, "traceparent"), text(&second, "traceparent"));
    }

    #[test]
    fn grpc_timeout_parses_each_unit() {
        let cases = [
            ("5S", Duration::from_secs(5)),
            ("100m", Duration::from_millis(100)),
            ("2M", Duration::from_secs(120)),
            ("1H", Duration::from_secs(3600)),
            ("250u", Duration::from_micros(250)),
            ("9n", Duration::from_nanos(9)),
        ];
        for (raw, expected) in cases {
            assert_eq!(parse_grpc_timeout(raw), Some(expected));
        }
    }

    #[test]
    fn grpc_timeout_rejects_malformed() {
        for raw in ["", "S", "10X", "abcS"] {
            assert_eq!(parse_grpc_timeout(raw), None);
        }
    }

    #[test]
    fn grpc_timeout_rejects_zero_duration() {
        for raw in ["0S", "0m", "0n"] {
            assert_eq!(parse_grpc_timeout(raw), None);
        }
    }

    #[test]
    fn grpc_timeout_enforces_8_digit_limit() {
        assert_eq!(
            parse_grpc_timeout("99999999S"),
            Some(Duration::from_secs(99_999_999))
        );
        assert_eq!(parse_grpc_timeout("999999999S"), None);
    }

    #[test]
    fn versioned_traceparent_is_forwarded() {
        let incoming = "01-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let meta = http_headers_to_grpc_metadata(&header_map(&[("traceparent", incoming)]), &[]);
        assert_eq!(text(&meta, "traceparent"), incoming);
    }

    #[test]
    fn ff_version_traceparent_is_rejected() {
        let invalid = "ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let meta = http_headers_to_grpc_metadata(&header_map(&[("traceparent", invalid)]), &[]);
        let tp = text(&meta, "traceparent");
        assert_ne!(tp, invalid);
        assert!(is_valid_traceparent(tp));
    }

    #[test]
    fn malformed_or_zero_traceparent_is_not_forwarded() {
        let zeros = "00-00000000000000000000000000000000-0000000000000000-01";
        let meta = http_headers_to_grpc_metadata(&header_map(&[("traceparent", zeros)]), &[]);
        let tp = text(&meta, "traceparent");
        assert_ne!(tp, zeros);
        assert!(
            is_valid_traceparent(tp),
            "synthesized traceparent invalid: {tp}"
        );
    }

    #[test]
    fn apply_request_deadline_sets_timeout_from_header() {
        let headers = header_map(&[("grpc-timeout", "3S")]);
        let mut req = tonic::Request::new(());
        assert_eq!(
            apply_request_deadline(&mut req, &headers),
            Some(Duration::from_secs(3))
        );
    }

    #[test]
    fn apply_request_deadline_noop_without_header() {
        let mut req = tonic::Request::new(());
        assert_eq!(apply_request_deadline(&mut req, &HeaderMap::new()), None);
    }

    #[test]
    fn test_dpop_forwarded() {
        let headers = header_map(&[("dpop", "eyJ0eXAiOiJkcG9wK2p3dCJ9")]);
        let meta = http_headers_to_grpc_metadata(&headers, &default_headers());
        assert!(meta.get("dpop").is_some());
    }
}