1use crate::headers::Headers;
9use crate::transport::h2::hpack_impl::{Decoder, Encoder};
10use bytes::Bytes;
11
/// Returns `true` if `a` and `b` are equal when ASCII letters are compared
/// case-insensitively.
///
/// Non-ASCII bytes must match exactly. This delegates to the standard
/// `[u8]::eq_ignore_ascii_case`, which also requires equal lengths.
fn bytes_eq_ignore_ascii_case(a: &[u8], b: &[u8]) -> bool {
    a.eq_ignore_ascii_case(b)
}
15
/// Order in which the four HTTP/2 request pseudo-headers (`:method`,
/// `:authority`, `:scheme`, `:path`) are emitted by [`HpackEncoder::encode_request`].
///
/// The order is visible to servers, which is why it appears in Akamai-style
/// HTTP/2 fingerprints (see [`PseudoHeaderOrder::akamai_string`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PseudoHeaderOrder {
    /// `:method, :scheme, :authority, :path` (`m,s,a,p`). This is the default.
    ///
    /// NOTE(review): some published Chrome fingerprints list `m,a,s,p`;
    /// confirm this still matches current Chrome captures.
    #[default]
    Chrome,
    /// `:method, :path, :authority, :scheme` (`m,p,a,s`).
    Firefox,
    /// `:method, :scheme, :path, :authority` (`m,s,p,a`).
    Safari,
    /// `:method, :authority, :scheme, :path` (`m,a,s,p`).
    Standard,
    /// Caller-supplied order. Each entry indexes
    /// `[:method, :authority, :scheme, :path]`, so the array must be a
    /// permutation of `0, 1, 2, 3`. Any other value falls back to the
    /// [`PseudoHeaderOrder::Standard`] order when encoding.
    Custom([u8; 4]),
}

impl PseudoHeaderOrder {
    /// Index order of [`PseudoHeaderOrder::Standard`], also used as the
    /// fallback for invalid `Custom` orders.
    const STANDARD_ORDER: [usize; 4] = [0, 1, 2, 3];

    /// Returns the emission order as indices into
    /// `[:method, :authority, :scheme, :path]`.
    ///
    /// A `Custom` order is only honoured if it is a permutation of `0..4`.
    /// An out-of-range index would otherwise panic when the encoder indexes
    /// its pseudo-header array. A repeated index would emit one pseudo-header
    /// twice and omit another, producing a malformed request (RFC 9113 §8.3).
    fn order(&self) -> [usize; 4] {
        match *self {
            Self::Chrome => [0, 2, 1, 3],
            Self::Firefox => [0, 3, 1, 2],
            Self::Safari => [0, 2, 3, 1],
            Self::Standard => Self::STANDARD_ORDER,
            Self::Custom(custom) => {
                let mut seen = [false; 4];
                for &idx in &custom {
                    let idx = usize::from(idx);
                    if idx >= seen.len() || seen[idx] {
                        return Self::STANDARD_ORDER;
                    }
                    seen[idx] = true;
                }
                custom.map(usize::from)
            }
        }
    }

    /// Returns the pseudo-header order in Akamai fingerprint notation:
    /// `m` = `:method`, `a` = `:authority`, `s` = `:scheme`, `p` = `:path`.
    ///
    /// `Custom` orders return the fixed placeholder `"custom"`.
    pub fn akamai_string(&self) -> &'static str {
        match self {
            Self::Chrome => "m,s,a,p",
            Self::Firefox => "m,p,a,s",
            Self::Safari => "m,s,p,a",
            Self::Standard => "m,a,s,p",
            Self::Custom(_) => "custom",
        }
    }
}
69
70pub struct HpackEncoder {
72 encoder: Encoder,
73 pseudo_order: PseudoHeaderOrder,
74}
75
76impl HpackEncoder {
77 pub fn new(pseudo_order: PseudoHeaderOrder) -> Self {
79 Self {
80 encoder: Encoder::new(),
81 pseudo_order,
82 }
83 }
84
85 pub fn chrome() -> Self {
87 Self::new(PseudoHeaderOrder::Chrome)
88 }
89
90 pub fn set_max_table_size(&mut self, size: usize) {
92 self.encoder.set_max_table_size(size);
93 }
94
95 pub fn encode_request(
100 &mut self,
101 method: &str,
102 scheme: &str,
103 authority: &str,
104 path: &str,
105 headers: impl Into<Headers>,
106 ) -> Bytes {
107 let headers = headers.into();
108 let pseudo_headers: [(&[u8], &[u8]); 4] = [
110 (b":method", method.as_bytes()),
111 (b":authority", authority.as_bytes()),
112 (b":scheme", scheme.as_bytes()),
113 (b":path", path.as_bytes()),
114 ];
115
116 let mut all_headers: Vec<(&[u8], &[u8])> = Vec::new();
118
119 let mut valid_headers: Vec<(Vec<u8>, &[u8])> = Vec::with_capacity(headers.len());
123
124 for (name, value) in headers.iter_bytes() {
125 if name.first() == Some(&b':') {
126 continue;
127 }
128
129 if name.is_empty() {
130 continue;
131 }
132 if name.iter().any(|&b| b < 0x21 || (b > 0x7E && b != 0x7F)) {
133 continue;
134 }
135
136 let name_lower = if name.iter().all(|b| b.is_ascii_lowercase()) {
137 name.to_vec()
138 } else {
139 name.iter().map(|b| b.to_ascii_lowercase()).collect()
140 };
141
142 if name_lower == b"connection"
143 || name_lower == b"keep-alive"
144 || name_lower == b"proxy-connection"
145 || name_lower == b"transfer-encoding"
146 || name_lower == b"upgrade"
147 {
148 continue;
149 }
150
151 if name_lower == b"te" && !bytes_eq_ignore_ascii_case(value, b"trailers") {
152 continue;
153 }
154
155 valid_headers.push((name_lower, value));
156 }
157
158 let order = self.pseudo_order.order();
160 for &idx in &order {
161 all_headers.push(pseudo_headers[idx]);
162 }
163
164 for (n, v) in &valid_headers {
166 all_headers.push((n.as_slice(), *v));
167 }
168
169 let encoded = self.encoder.encode(&all_headers);
171 Bytes::from(encoded)
172 }
173
174 pub fn encode_extended_connect_websocket(
179 &mut self,
180 authority: &str,
181 scheme: &str,
182 path: &str,
183 headers: impl Into<Headers>,
184 ) -> Result<Bytes, String> {
185 let headers = headers.into();
186 if authority.is_empty() {
187 return Err(":authority must not be empty".to_string());
188 }
189 if scheme.is_empty() {
190 return Err(":scheme must not be empty".to_string());
191 }
192 if path.is_empty() {
193 return Err(":path must not be empty".to_string());
194 }
195
196 let pseudo_headers: [(&[u8], &[u8]); 5] = [
197 (b":method", b"CONNECT"),
198 (b":protocol", b"websocket"),
199 (b":scheme", scheme.as_bytes()),
200 (b":path", path.as_bytes()),
201 (b":authority", authority.as_bytes()),
202 ];
203
204 let mut valid_headers: Vec<(Vec<u8>, &[u8])> = Vec::with_capacity(headers.len());
205
206 for (name, value) in headers.iter_bytes() {
207 if name.first() == Some(&b':') {
208 return Err(format!(
209 "RFC 8441 user pseudo-header rejected: {}",
210 String::from_utf8_lossy(name)
211 ));
212 }
213
214 if name.is_empty() {
215 return Err("RFC 8441 header name must not be empty".to_string());
216 }
217 if name.iter().any(|&b| b < 0x21 || (b > 0x7E && b != 0x7F)) {
218 return Err(format!(
219 "RFC 8441 invalid header name rejected: {}",
220 String::from_utf8_lossy(name)
221 ));
222 }
223
224 let name_lower = if name.iter().all(|b| b.is_ascii_lowercase()) {
225 name.to_vec()
226 } else {
227 name.iter().map(|b| b.to_ascii_lowercase()).collect()
228 };
229 if matches!(
230 name_lower.as_slice(),
231 b"connection"
232 | b"upgrade"
233 | b"host"
234 | b"sec-websocket-key"
235 | b"sec-websocket-accept"
236 | b"sec-websocket-extensions"
237 | b"keep-alive"
238 | b"proxy-connection"
239 | b"transfer-encoding"
240 ) {
241 return Err(format!(
242 "RFC 8441 forbidden header rejected: {}",
243 String::from_utf8_lossy(&name_lower)
244 ));
245 }
246
247 if name_lower == b"te" && !bytes_eq_ignore_ascii_case(value, b"trailers") {
248 return Err("RFC 8441 forbids TE values other than trailers".to_string());
249 }
250
251 valid_headers.push((name_lower, value));
252 }
253
254 let mut all_headers: Vec<(&[u8], &[u8])> =
255 Vec::with_capacity(pseudo_headers.len() + valid_headers.len());
256 all_headers.extend_from_slice(&pseudo_headers);
257 for (name, value) in &valid_headers {
258 all_headers.push((name.as_slice(), *value));
259 }
260
261 let encoded = self.encoder.encode(&all_headers);
262 Ok(Bytes::from(encoded))
263 }
264
265 pub fn chunk_encoded(encoded: Bytes, max_frame_size: usize) -> (Bytes, Vec<Bytes>) {
273 if encoded.len() <= max_frame_size {
274 return (encoded, Vec::new());
276 }
277
278 let mut chunks: Vec<Bytes> = encoded
280 .chunks(max_frame_size)
281 .map(Bytes::copy_from_slice)
282 .collect();
283
284 let first = chunks.remove(0);
285 (first, chunks)
286 }
287}
288
289pub struct HpackDecoder {
291 decoder: Decoder,
292}
293
294impl HpackDecoder {
295 pub fn new() -> Self {
297 Self {
298 decoder: Decoder::new(),
299 }
300 }
301
302 pub fn set_max_table_size(&mut self, size: usize) {
304 self.decoder.set_max_table_size(size);
305 }
306
307 pub fn decode(&mut self, data: &[u8]) -> Result<Vec<(String, String)>, String> {
309 let mut headers = Vec::new();
310
311 self.decoder
312 .decode_with_cb(data, |name, value| {
313 let name_str = String::from_utf8_lossy(name).into_owned();
314 let value_str = String::from_utf8_lossy(value).into_owned();
315 headers.push((name_str, value_str));
316 })
317 .map_err(|e| format!("HPACK decode error: {:?}", e))?;
318
319 Ok(headers)
320 }
321}
322
323impl Default for HpackDecoder {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pseudo_order_chrome() {
        let order = PseudoHeaderOrder::Chrome;
        assert_eq!(order.akamai_string(), "m,s,a,p");
    }

    #[test]
    fn test_pseudo_order_standard() {
        let order = PseudoHeaderOrder::Standard;
        assert_eq!(order.akamai_string(), "m,a,s,p");
    }

    #[test]
    fn test_pseudo_order_firefox_and_safari() {
        assert_eq!(PseudoHeaderOrder::Firefox.akamai_string(), "m,p,a,s");
        assert_eq!(PseudoHeaderOrder::Safari.akamai_string(), "m,s,p,a");
    }

    #[test]
    fn test_encoder_creates_valid_block() {
        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &Headers::from(vec![("user-agent".to_string(), "test".to_string())]),
        );

        assert!(!block.is_empty());

        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();

        assert_eq!(headers.len(), 5);

        assert_eq!(headers[0].0, ":method");
        assert_eq!(headers[0].1, "GET");
        assert_eq!(headers[1].0, ":scheme");
        assert_eq!(headers[1].1, "https");
        assert_eq!(headers[2].0, ":authority");
        assert_eq!(headers[2].1, "example.com");
        assert_eq!(headers[3].0, ":path");
        assert_eq!(headers[3].1, "/");
        assert_eq!(headers[4].0, "user-agent");
        assert_eq!(headers[4].1, "test");
    }

    #[test]
    fn test_encoder_standard_order() {
        let mut encoder = HpackEncoder::new(PseudoHeaderOrder::Standard);
        let block = encoder.encode_request("GET", "https", "example.com", "/", &Headers::new());

        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();

        assert_eq!(headers[0].0, ":method");
        assert_eq!(headers[1].0, ":authority");
        assert_eq!(headers[2].0, ":scheme");
        assert_eq!(headers[3].0, ":path");
    }

    #[test]
    fn test_encoder_filters_connection_headers() {
        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &Headers::from(vec![
                ("connection".to_string(), "keep-alive".to_string()),
                ("keep-alive".to_string(), "timeout=5".to_string()),
                ("user-agent".to_string(), "test".to_string()),
            ]),
        );

        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();

        assert_eq!(headers.len(), 5);
        assert_eq!(headers[4].0, "user-agent");
    }

    #[test]
    fn test_encoder_keeps_only_te_trailers() {
        // Each block gets its own encoder/decoder pair so HPACK state stays in sync.
        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &Headers::from(vec![("te".to_string(), "trailers".to_string())]),
        );
        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();
        assert_eq!(headers.len(), 5);
        assert_eq!(headers[4].0, "te");
        assert_eq!(headers[4].1, "trailers");

        let mut encoder = HpackEncoder::chrome();
        let block = encoder.encode_request(
            "GET",
            "https",
            "example.com",
            "/",
            &Headers::from(vec![("te".to_string(), "gzip".to_string())]),
        );
        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();
        assert_eq!(headers.len(), 4);
    }

    #[test]
    fn test_extended_connect_websocket_layout() {
        let mut encoder = HpackEncoder::chrome();
        let block = encoder
            .encode_extended_connect_websocket(
                "example.com",
                "https",
                "/chat",
                &Headers::from(vec![(
                    "sec-websocket-version".to_string(),
                    "13".to_string(),
                )]),
            )
            .unwrap();

        let mut decoder = HpackDecoder::new();
        let headers = decoder.decode(&block).unwrap();

        let expected = [
            (":method", "CONNECT"),
            (":protocol", "websocket"),
            (":scheme", "https"),
            (":path", "/chat"),
            (":authority", "example.com"),
            ("sec-websocket-version", "13"),
        ];
        assert_eq!(headers.len(), expected.len());
        for (got, want) in headers.iter().zip(expected.iter()) {
            assert_eq!((got.0.as_str(), got.1.as_str()), *want);
        }
    }

    #[test]
    fn test_extended_connect_websocket_rejects_bad_input() {
        let mut encoder = HpackEncoder::chrome();
        assert!(encoder
            .encode_extended_connect_websocket("", "https", "/", &Headers::new())
            .is_err());
        assert!(encoder
            .encode_extended_connect_websocket(
                "example.com",
                "https",
                "/",
                &Headers::from(vec![("connection".to_string(), "upgrade".to_string())]),
            )
            .is_err());
    }

    #[test]
    fn test_chunk_encoded_splits_into_frames() {
        let (first, rest) = HpackEncoder::chunk_encoded(Bytes::from_static(b"abcdefg"), 3);
        assert_eq!(&first[..], b"abc");
        assert_eq!(rest.len(), 2);
        assert_eq!(&rest[0][..], b"def");
        assert_eq!(&rest[1][..], b"g");

        let (first, rest) = HpackEncoder::chunk_encoded(Bytes::from_static(b"abc"), 3);
        assert_eq!(&first[..], b"abc");
        assert!(rest.is_empty());
    }
}