1use crate::error::{Error, ErrorKind, ResultExt};
9use std::{
10 borrow::Cow, collections::HashSet, convert::Infallible, fmt, str::FromStr, sync::LazyLock,
11};
12
/// Lower-case header names whose values are printed verbatim by the `Debug`
/// implementation of [`Headers`]; the value of every other header is replaced
/// by the redaction placeholder.
pub static DEFAULT_ALLOWED_HEADER_NAMES: LazyLock<HashSet<Cow<'static, str>>> =
    LazyLock::new(|| {
        // Kept as a plain slice so the list is easy to scan and extend.
        const NAMES: &[&str] = &[
            "accept",
            "cache-control",
            "connection",
            "content-length",
            "content-type",
            "date",
            "etag",
            "expires",
            "if-match",
            "if-modified-since",
            "if-none-match",
            "if-unmodified-since",
            "last-modified",
            "ms-cv",
            "pragma",
            "request-id",
            "retry-after",
            "server",
            "traceparent",
            "tracestate",
            "transfer-encoding",
            "user-agent",
            "www-authenticate",
            "x-ms-request-id",
            "x-ms-client-request-id",
            "x-ms-return-client-request-id",
        ];

        NAMES.iter().copied().map(Cow::Borrowed).collect()
    });
48
49pub trait AsHeaders {
51 type Error: std::error::Error + Send + Sync + 'static;
53
54 type Iter: Iterator<Item = (HeaderName, HeaderValue)>;
56
57 fn as_headers(&self) -> Result<Self::Iter, Self::Error>;
59}
60
61impl<T> AsHeaders for T
62where
63 T: Header,
64{
65 type Error = Infallible;
66 type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>;
67
68 fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
70 Ok(vec![(self.name(), self.value())].into_iter())
71 }
72}
73
74impl<T> AsHeaders for Option<T>
75where
76 T: AsHeaders<Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>>,
77{
78 type Error = T::Error;
79 type Iter = T::Iter;
80
81 fn as_headers(&self) -> Result<Self::Iter, T::Error> {
83 match self {
84 Some(h) => h.as_headers(),
85 None => Ok(vec![].into_iter()),
86 }
87 }
88}
89
90pub trait FromHeaders: Sized {
94 type Error: std::error::Error + Send + Sync + 'static;
96
97 fn header_names() -> &'static [&'static str];
101
102 fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error>;
109}
110
111pub trait Header {
121 fn name(&self) -> HeaderName;
123 fn value(&self) -> HeaderValue;
125}
126
127#[derive(Clone, PartialEq, Eq, Default)]
129pub struct Headers(std::collections::HashMap<HeaderName, HeaderValue>);
130
131impl Headers {
132 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn get<H: FromHeaders>(&self) -> crate::Result<H> {
139 match H::from_headers(self) {
140 Ok(Some(x)) => Ok(x),
141 Ok(None) => Err(crate::Error::with_message_fn(
142 ErrorKind::DataConversion,
143 || {
144 let required_headers = H::header_names();
145 format!(
146 "required header(s) not found: {}",
147 required_headers.join(", ")
148 )
149 },
150 )),
151 Err(e) => Err(crate::Error::new(ErrorKind::DataConversion, e)),
152 }
153 }
154
155 pub fn get_optional<H: FromHeaders>(&self) -> Result<Option<H>, H::Error> {
162 H::from_headers(self)
163 }
164
165 pub fn get_optional_string(&self, key: &HeaderName) -> Option<String> {
167 self.get_as(key).ok()
168 }
169
170 pub fn get_str(&self, key: &HeaderName) -> crate::Result<&str> {
172 self.get_with(key, |s| crate::Result::Ok(s.as_str()))
173 }
174
175 pub fn get_optional_str(&self, key: &HeaderName) -> Option<&str> {
177 self.get_str(key).ok()
178 }
179
180 pub fn get_as<V, E>(&self, key: &HeaderName) -> crate::Result<V>
182 where
183 V: FromStr<Err = E>,
184 E: std::error::Error + Send + Sync + 'static,
185 {
186 self.get_with(key, |s| s.as_str().parse())
187 }
188
189 pub fn get_optional_as<V, E>(&self, key: &HeaderName) -> crate::Result<Option<V>>
191 where
192 V: FromStr<Err = E>,
193 E: std::error::Error + Send + Sync + 'static,
194 {
195 self.get_optional_with(key, |s| s.as_str().parse())
196 }
197
198 pub fn get_with<'a, V, F, E>(&'a self, key: &HeaderName, parser: F) -> crate::Result<V>
200 where
201 F: FnOnce(&'a HeaderValue) -> Result<V, E>,
202 E: std::error::Error + Send + Sync + 'static,
203 {
204 self.get_optional_with(key, parser)?.ok_or_else(|| {
205 Error::with_message_fn(ErrorKind::DataConversion, || {
206 format!("header not found {}", key.as_str())
207 })
208 })
209 }
210
211 pub fn get_optional_with<'a, V, F, E>(
213 &'a self,
214 key: &HeaderName,
215 parser: F,
216 ) -> crate::Result<Option<V>>
217 where
218 F: FnOnce(&'a HeaderValue) -> Result<V, E>,
219 E: std::error::Error + Send + Sync + 'static,
220 {
221 self.0
222 .get(key)
223 .map(|v: &HeaderValue| {
224 parser(v).with_context_fn(ErrorKind::DataConversion, || {
225 let ty = std::any::type_name::<V>();
226 format!("unable to parse header '{key:?}: {v:?}' into {ty}",)
227 })
228 })
229 .transpose()
230 }
231
232 pub fn insert<K, V>(&mut self, key: K, value: V)
234 where
235 K: Into<HeaderName>,
236 V: Into<HeaderValue>,
237 {
238 self.0.insert(key.into(), value.into());
239 }
240
241 pub fn add<H>(&mut self, header: H) -> Result<(), H::Error>
249 where
250 H: AsHeaders,
251 {
252 for (key, value) in header.as_headers()? {
253 self.insert(key, value);
254 }
255 Ok(())
256 }
257
258 pub fn iter(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
260 self.0.iter()
261 }
262
263 pub fn remove<K>(&mut self, key: K) -> Option<HeaderValue>
265 where
266 K: Into<HeaderName>,
267 {
268 self.0.remove(&key.into())
269 }
270}
271
272impl fmt::Debug for Headers {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_map()
276 .entries(self.0.iter().map(|(k, v)| {
277 (
278 k.as_str(),
279 if DEFAULT_ALLOWED_HEADER_NAMES.contains(k.as_str()) {
280 v.as_str()
281 } else {
282 super::REDACTED_PATTERN
283 },
284 )
285 }))
286 .finish()
287 }
288}
289
290impl IntoIterator for Headers {
291 type Item = (HeaderName, HeaderValue);
292
293 type IntoIter = std::collections::hash_map::IntoIter<HeaderName, HeaderValue>;
294
295 fn into_iter(self) -> Self::IntoIter {
296 self.0.into_iter()
297 }
298}
299
300impl From<std::collections::HashMap<HeaderName, HeaderValue>> for Headers {
301 fn from(c: std::collections::HashMap<HeaderName, HeaderValue>) -> Self {
302 Self(c)
303 }
304}
305
306#[derive(Clone, Debug, Eq, PartialOrd, Ord)]
308pub struct HeaderName {
309 name: Cow<'static, str>,
311
312 pub(crate) is_standard: bool,
315}
316
317impl HeaderName {
318 pub const fn from_static(s: &'static str) -> Self {
320 ensure_no_uppercase(s);
321 Self {
322 name: Cow::Borrowed(s),
323 is_standard: false,
324 }
325 }
326
327 pub const fn from_static_standard(s: &'static str) -> Self {
329 ensure_no_uppercase(s);
330 Self {
331 name: Cow::Borrowed(s),
332 is_standard: true,
333 }
334 }
335
336 fn from_cow<C>(c: C) -> Self
337 where
338 C: Into<Cow<'static, str>>,
339 {
340 let c = c.into();
341 assert!(
342 c.chars().all(|c| c.is_lowercase() || !c.is_alphabetic()),
343 "header names must be lowercase: {c}"
344 );
345 Self {
346 name: c,
347 is_standard: false,
348 }
349 }
350
351 pub fn as_str(&self) -> &str {
353 self.name.as_ref()
354 }
355
356 pub fn is_standard(&self) -> bool {
358 self.is_standard
359 }
360}
361
362impl PartialEq for HeaderName {
363 fn eq(&self, other: &Self) -> bool {
364 self.name.eq_ignore_ascii_case(&other.name)
365 }
366}
367
368impl PartialEq<&str> for HeaderName {
369 fn eq(&self, other: &&str) -> bool {
370 self.name.eq_ignore_ascii_case(other)
371 }
372}
373
374impl std::hash::Hash for HeaderName {
375 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
376 std::hash::Hash::hash(&self.name, state);
378 }
379}
380
/// Panics if `s` contains an ASCII upper-case letter (`A`-`Z`).
///
/// Written with an index loop so it can be called from `const fn`
/// constructors.
const fn ensure_no_uppercase(s: &str) {
    let raw = s.as_bytes();
    let mut pos = 0;
    while pos < raw.len() {
        if raw[pos].is_ascii_uppercase() {
            panic!("header names must not contain uppercase letters");
        }
        pos += 1;
    }
}
394
395impl From<&'static str> for HeaderName {
396 fn from(s: &'static str) -> Self {
397 Self::from_cow(s)
398 }
399}
400
401impl From<String> for HeaderName {
402 fn from(s: String) -> Self {
403 Self::from_cow(s.to_lowercase())
404 }
405}
406
/// The value of a header, held as either a borrowed static string or an
/// owned one.
#[derive(Clone, Eq, PartialEq)]
pub struct HeaderValue(Cow<'static, str>);
410
411impl HeaderValue {
412 pub const fn from_static(s: &'static str) -> Self {
414 Self(Cow::Borrowed(s))
415 }
416
417 pub fn from_cow<C>(c: C) -> Self
419 where
420 C: Into<Cow<'static, str>>,
421 {
422 Self(c.into())
423 }
424
425 pub fn as_str(&self) -> &str {
427 self.0.as_ref()
428 }
429}
430
431impl fmt::Debug for HeaderValue {
432 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
433 f.write_str("HeaderValue")
434 }
435}
436
437impl From<&'static str> for HeaderValue {
438 fn from(s: &'static str) -> Self {
439 Self::from_cow(s)
440 }
441}
442
443impl From<String> for HeaderValue {
444 fn from(s: String) -> Self {
445 Self::from_cow(s)
446 }
447}
448
449impl From<&String> for HeaderValue {
450 fn from(s: &String) -> Self {
451 s.clone().into()
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::{FromHeaders, HeaderName, Headers};
    use crate::error::ErrorKind;
    use url::Url;

    /// Test-only typed view of the `content-location` header.
    #[derive(Debug)]
    struct ContentLocationForTest(Url);

    impl FromHeaders for ContentLocationForTest {
        type Error = url::ParseError;

        fn header_names() -> &'static [&'static str] {
            &["content-location"]
        }

        fn from_headers(headers: &super::Headers) -> Result<Option<Self>, Self::Error> {
            // Absent header -> Ok(None); present header -> parse it as a URL.
            headers
                .get_optional_str(&HeaderName::from("content-location"))
                .map(|raw| raw.parse::<Url>().map(ContentLocationForTest))
                .transpose()
        }
    }

    #[test]
    pub fn headers_get_optional_returns_ok_some_if_header_present_and_valid() {
        let mut bag = Headers::new();
        bag.insert("content-location", "https://example.com");

        let parsed = bag
            .get_optional::<ContentLocationForTest>()
            .unwrap()
            .unwrap();
        assert_eq!("https://example.com/", parsed.0.as_str())
    }

    #[test]
    pub fn headers_get_optional_returns_ok_none_if_header_not_present() {
        let bag = Headers::new();
        let found = bag.get_optional::<ContentLocationForTest>().unwrap();
        assert!(found.is_none())
    }

    #[test]
    pub fn headers_get_optional_returns_err_if_conversion_fails() {
        let mut bag = Headers::new();
        bag.insert("content-location", "not a URL");

        let failure = bag.get_optional::<ContentLocationForTest>().unwrap_err();
        assert_eq!(url::ParseError::RelativeUrlWithoutBase, failure)
    }

    #[test]
    pub fn headers_get_returns_ok_if_header_present_and_valid() {
        let mut bag = Headers::new();
        bag.insert("content-location", "https://example.com");

        let parsed = bag.get::<ContentLocationForTest>().unwrap();
        assert_eq!("https://example.com/", parsed.0.as_str())
    }

    #[test]
    pub fn headers_get_returns_err_if_header_not_present() {
        let bag = Headers::new();
        let err = bag.get::<ContentLocationForTest>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, err.kind());

        // The message names the single header the type requires.
        assert_eq!(
            "required header(s) not found: content-location",
            err.to_string()
        );
    }

    #[test]
    pub fn headers_get_returns_err_if_header_requiring_multiple_headers_not_present() {
        /// A type that needs two headers and never finds them.
        #[derive(Debug)]
        struct HasTwoHeaders;

        impl FromHeaders for HasTwoHeaders {
            type Error = std::convert::Infallible;

            fn header_names() -> &'static [&'static str] {
                &["header-a", "header-b"]
            }

            fn from_headers(_: &Headers) -> Result<Option<Self>, Self::Error> {
                Ok(None)
            }
        }

        let bag = Headers::new();
        let err = bag.get::<HasTwoHeaders>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, err.kind());

        // All required header names are listed, comma-separated.
        assert_eq!(
            "required header(s) not found: header-a, header-b",
            err.to_string()
        );
    }

    #[test]
    pub fn headers_get_returns_err_if_conversion_fails() {
        let mut bag = Headers::new();
        bag.insert("content-location", "not a URL");

        let err = bag.get::<ContentLocationForTest>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, err.kind());

        // The original parse error must be recoverable from the wrapper.
        let source = err
            .into_inner()
            .unwrap()
            .downcast::<url::ParseError>()
            .unwrap();
        assert_eq!(url::ParseError::RelativeUrlWithoutBase, *source)
    }

    #[test]
    pub fn headers_remove_existing_header_returns_value() {
        let mut bag = Headers::new();
        bag.insert("test-header", "test-value");

        // Present before removal.
        let key = HeaderName::from("test-header");
        assert_eq!(bag.get_optional_str(&key), Some("test-value"));

        // Removal hands back the stored value.
        let removed = bag.remove("test-header");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().as_str(), "test-value");

        // Gone afterwards.
        assert_eq!(bag.get_optional_str(&key), None);
    }

    #[test]
    pub fn headers_remove_nonexistent_header_returns_none() {
        let mut bag = Headers::new();

        let removed = bag.remove("nonexistent-header");
        assert_eq!(removed, None);
    }

    #[test]
    pub fn headers_remove_works_with_different_key_types() {
        /// Inserts the fixture header, removes it via `key`, and checks the
        /// removed value.
        fn insert_then_remove<K: Into<HeaderName>>(bag: &mut Headers, key: K) {
            bag.insert("test-header", "test-value");

            let removed = bag.remove(key);
            assert!(removed.is_some());
            assert_eq!(removed.unwrap().as_str(), "test-value");
        }

        let mut bag = Headers::new();

        // &'static str key.
        insert_then_remove(&mut bag, "test-header");
        // HeaderName key.
        insert_then_remove(&mut bag, HeaderName::from("test-header"));
        // String key.
        insert_then_remove(&mut bag, "test-header".to_string());
    }
}