Skip to main content

shift_proxy/
body.rs

1//! Request body extraction with transparent decompression.
2//!
3//! Clients like Codex CLI send compressed request bodies (gzip or zstd).
4//! Axum's `String` extractor rejects these because raw compressed bytes
5//! aren't valid UTF-8. This module extracts the raw `Bytes`, decompresses
6//! if needed, and converts to a UTF-8 string.
7
8use axum::body::Bytes;
9use axum::http::HeaderMap;
10use flate2::read::GzDecoder;
11use std::io::Read;
12use zstd::stream::read::Decoder as ZstdDecoder;
13
/// Upper bound on the size of a decompressed request body: 512 MiB.
///
/// `DefaultBodyLimit` (configured in `mod.rs`) bounds the *compressed* input
/// at 200 MB, but a crafted gzip/zstd payload can expand roughly 1000x or
/// more. This constant bounds the *decompressed* output so that a small
/// request cannot exhaust memory.
const MAX_DECOMPRESSED_SIZE: u64 = 512 << 20;
20
21/// Extract the request body as a UTF-8 string, decompressing if the
22/// `Content-Encoding` header indicates compression.
23///
24/// Supported encodings:
25/// - `gzip` — decompressed via flate2
26/// - `zstd` — decompressed via the zstd crate (used by Codex CLI)
27/// - (none / identity) — passed through as-is
28///
29/// When no `Content-Encoding` header is present, sniffs gzip (0x1f 0x8b)
30/// and zstd (0x28 0xb5 0x2f 0xfd) magic bytes as a fallback.
31pub fn extract_body(headers: &HeaderMap, raw: Bytes) -> Result<String, String> {
32    let encoding = headers
33        .get("content-encoding")
34        .and_then(|v| v.to_str().ok())
35        .unwrap_or("");
36
37    // Detect compression via header first, then magic bytes only when
38    // no Content-Encoding header is set.
39    let no_encoding = encoding.is_empty();
40    let has_gzip_magic = raw.len() >= 2 && raw[0] == 0x1f && raw[1] == 0x8b;
41    let has_zstd_magic =
42        raw.len() >= 4 && raw[0] == 0x28 && raw[1] == 0xb5 && raw[2] == 0x2f && raw[3] == 0xfd;
43
44    let is_gzip = encoding.eq_ignore_ascii_case("gzip") || (no_encoding && has_gzip_magic);
45    let is_zstd = encoding.eq_ignore_ascii_case("zstd") || (no_encoding && has_zstd_magic);
46
47    let bytes = if is_gzip {
48        let mut decoded = Vec::new();
49        GzDecoder::new(&raw[..])
50            .take(MAX_DECOMPRESSED_SIZE)
51            .read_to_end(&mut decoded)
52            .map_err(|e| format!("gzip decode error: {e}"))?;
53        decoded
54    } else if is_zstd {
55        let mut decoded = Vec::new();
56        ZstdDecoder::new(&raw[..])
57            .map_err(|e| format!("zstd init error: {e}"))?
58            .take(MAX_DECOMPRESSED_SIZE)
59            .read_to_end(&mut decoded)
60            .map_err(|e| format!("zstd decode error: {e}"))?;
61        decoded
62    } else {
63        raw.to_vec()
64    };
65
66    String::from_utf8(bytes).map_err(|e| format!("invalid UTF-8: {e}"))
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write;

    #[test]
    fn test_extract_plain_body() {
        let headers = HeaderMap::new();
        let body = Bytes::from(r#"{"model":"gpt-4o"}"#);
        let result = extract_body(&headers, body).unwrap();
        assert_eq!(result, r#"{"model":"gpt-4o"}"#);
    }

    #[test]
    fn test_extract_empty_body() {
        let headers = HeaderMap::new();
        assert_eq!(extract_body(&headers, Bytes::new()).unwrap(), "");
    }

    #[test]
    fn test_identity_encoding_passes_through() {
        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "identity".parse().unwrap());
        let body = Bytes::from(r#"{"model":"gpt-4o"}"#);
        assert_eq!(extract_body(&headers, body).unwrap(), r#"{"model":"gpt-4o"}"#);
    }

    #[test]
    fn test_extract_gzip_body() {
        let original = r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"hello"}]}"#;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original.as_bytes()).unwrap();
        let compressed = encoder.finish().unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "gzip".parse().unwrap());

        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_extract_gzip_body_case_insensitive() {
        let original = "hello world";

        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original.as_bytes()).unwrap();
        let compressed = encoder.finish().unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "Gzip".parse().unwrap());

        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_extract_gzip_body_magic_bytes_no_header() {
        // Codex CLI sends gzip without Content-Encoding header
        let original = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"test"}]}"#;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original.as_bytes()).unwrap();
        let compressed = encoder.finish().unwrap();

        // Verify gzip magic bytes are present
        assert_eq!(compressed[0], 0x1f);
        assert_eq!(compressed[1], 0x8b);

        // No Content-Encoding header — should still decompress via magic byte detection
        let headers = HeaderMap::new();
        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_extract_zstd_body() {
        let original = r#"{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}]}"#;
        let compressed = zstd::encode_all(original.as_bytes(), 3).unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "zstd".parse().unwrap());

        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_extract_zstd_body_case_insensitive() {
        let original = "hello world";
        let compressed = zstd::encode_all(original.as_bytes(), 3).unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "ZSTD".parse().unwrap());

        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_extract_zstd_body_magic_bytes_no_header() {
        let original = r#"{"model":"gpt-5.4","messages":[{"role":"user","content":"test"}]}"#;
        let compressed = zstd::encode_all(original.as_bytes(), 3).unwrap();

        // Verify zstd magic bytes
        assert_eq!(compressed[0], 0x28);
        assert_eq!(compressed[1], 0xb5);
        assert_eq!(compressed[2], 0x2f);
        assert_eq!(compressed[3], 0xfd);

        // No Content-Encoding header
        let headers = HeaderMap::new();
        let result = extract_body(&headers, Bytes::from(compressed)).unwrap();
        assert_eq!(result, original);
    }

    #[test]
    fn test_magic_bytes_ignored_when_encoding_header_set() {
        // If Content-Encoding is explicitly set to an encoding this module does
        // not support, magic byte sniffing should NOT fire — respect the header.
        let original = r#"{"test": true}"#;
        let compressed = zstd::encode_all(original.as_bytes(), 3).unwrap();

        // `br` (brotli) is a real encoding but unsupported here — the zstd
        // magic bytes in the body must not trigger zstd decoding.
        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "br".parse().unwrap());

        // The raw zstd bytes pass through undecoded and therefore fail the
        // UTF-8 check; assert on the reason, not just on failure.
        let err = extract_body(&headers, Bytes::from(compressed)).unwrap_err();
        assert!(err.contains("invalid UTF-8"), "unexpected error: {err}");
    }

    #[test]
    fn test_extract_invalid_utf8() {
        let headers = HeaderMap::new();
        // Invalid UTF-8 sequence
        let body = Bytes::from(vec![0xff, 0xfe, 0xfd]);
        let result = extract_body(&headers, body);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid UTF-8"));
    }

    #[test]
    fn test_extract_invalid_gzip() {
        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "gzip".parse().unwrap());
        // Not valid gzip data
        let body = Bytes::from(vec![0x00, 0x01, 0x02, 0x03]);
        let result = extract_body(&headers, body);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("gzip decode error"));
    }

    #[test]
    fn test_extract_invalid_zstd() {
        let mut headers = HeaderMap::new();
        headers.insert("content-encoding", "zstd".parse().unwrap());
        // Not a zstd frame: the first four bytes are not the zstd magic number
        let body = Bytes::from(vec![0x00, 0x01, 0x02, 0x03]);
        let err = extract_body(&headers, body).unwrap_err();
        assert!(err.contains("zstd"), "unexpected error: {err}");
    }
}