1use core::{fmt, ops::Deref};
4
5use crate::{
6 body::ResponseBody,
7 context::WebContext,
8 error::{Error, HeaderNotFound},
9 handler::{FromRequest, Responder},
10 http::{
11 WebResponse,
12 header::{self, HeaderMap, HeaderName, HeaderValue},
13 },
14};
15
/// Expands a comma separated list of identifiers into `pub const` items of type `usize`,
/// numbered consecutively in list order.
///
/// `const_header_name!(A, B)` starts at 0 and produces `pub const A: usize = 0;` and
/// `pub const B: usize = 0 + 1;`. The `start; A, B` form starts counting at `start` instead.
macro_rules! const_header_name {
    // Entry point: number the identifiers starting from zero.
    ($($name:ident),+) => {
        const_header_name!(0; $($name),*);
    };
    // Recursion terminator: every identifier has been assigned an index.
    ($idx:expr ;) => {};
    // Emit a constant for the head identifier, then recurse on the tail with the next index.
    ($idx:expr ; $head:ident $(, $tail:ident)*) => {
        pub const $head: usize = $idx;
        const_header_name!($idx + 1; $($tail),*);
    };
}
24
/// Generates `const fn map_to_header_name<const HEADER_NAME: usize>() -> header::HeaderName`.
///
/// The generated function matches `HEADER_NAME` against the `usize` constants named by the
/// identifier list (as produced by `const_header_name!`) and returns the `header::$name`
/// constant with the same identifier.
///
/// # Panics
///
/// The generated function panics when `HEADER_NAME` matches none of the listed constants.
/// `HEADER_NAME` is a free const generic that callers choose (e.g. `HeaderRef<'_, 99>`), so this
/// arm is reachable. It therefore reports a descriptive message instead of `unreachable!()`.
macro_rules! map_to_header_name {
    ($($i:ident), +) => {
        const fn map_to_header_name<const HEADER_NAME: usize>() -> header::HeaderName {
            match HEADER_NAME {
                $(
                    $i => header::$i,
                )*
                // A plain string literal keeps this panic usable inside a `const fn`.
                _ => panic!("map_to_header_name: HEADER_NAME does not match any supported header name index"),
            }
        }
    };
}
37
/// Generates, from a single identifier list, both:
/// - the `usize` index constants (via `const_header_name!`), and
/// - the matching `map_to_header_name` lookup function (via `map_to_header_name!`).
///
/// Feeding both macros the same list keeps the indices and the mapping in agreement.
macro_rules! const_header_name_impl {
    ($($name:ident),+) => {
        const_header_name!($($name),+);
        map_to_header_name!($($name),+);
    };
}
44
45const_header_name_impl!(ACCEPT, ACCEPT_ENCODING, HOST, CONTENT_TYPE, CONTENT_LENGTH);
46
/// Extractor borrowing the value of a single request header.
///
/// `HEADER_NAME` selects the header through one of this module's index constants
/// (`ACCEPT`, `ACCEPT_ENCODING`, `HOST`, `CONTENT_TYPE`, `CONTENT_LENGTH`), e.g.
/// `HeaderRef<'_, { ACCEPT }>`.
///
/// Extraction fails with a `HeaderNotFound` service error when the request lacks that header.
/// The extractor dereferences to the borrowed [`HeaderValue`].
pub struct HeaderRef<'a, const HEADER_NAME: usize>(&'a HeaderValue);
78
79impl<const HEADER_NAME: usize> fmt::Debug for HeaderRef<'_, HEADER_NAME> {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 f.debug_struct("Header")
82 .field("name", &map_to_header_name::<HEADER_NAME>())
83 .field("value", &self.0)
84 .finish()
85 }
86}
87
88impl<const HEADER_NAME: usize> Deref for HeaderRef<'_, HEADER_NAME> {
89 type Target = HeaderValue;
90
91 fn deref(&self) -> &Self::Target {
92 self.0
93 }
94}
95
96impl<'a, 'r, C, B, const HEADER_NAME: usize> FromRequest<'a, WebContext<'r, C, B>> for HeaderRef<'a, HEADER_NAME> {
97 type Type<'b> = HeaderRef<'b, HEADER_NAME>;
98 type Error = Error;
99
100 #[inline]
101 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
102 let name = map_to_header_name::<HEADER_NAME>();
103 ctx.req()
104 .headers()
105 .get(&name)
106 .map(HeaderRef)
107 .ok_or_else(|| Error::from_service(HeaderNotFound(name)))
108 }
109}
110
111impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for &'a HeaderMap {
112 type Type<'b> = &'b HeaderMap;
113 type Error = Error;
114
115 #[inline]
116 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
117 Ok(ctx.req().headers())
118 }
119}
120
121impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for HeaderMap {
122 type Type<'b> = HeaderMap;
123 type Error = Error;
124
125 #[inline]
126 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
127 Ok(ctx.req().headers().clone())
128 }
129}
130
131impl<'r, C, B> Responder<WebContext<'r, C, B>> for (HeaderName, HeaderValue) {
132 type Response = WebResponse;
133 type Error = Error;
134
135 async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
136 let res = ctx.into_response(ResponseBody::empty());
137 Responder::<WebContext<'r, C, B>>::map(self, res)
138 }
139
140 fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
141 res.headers_mut().append(self.0, self.1);
142 Ok(res)
143 }
144}
145
146impl<'r, C, B, const N: usize> Responder<WebContext<'r, C, B>> for [(HeaderName, HeaderValue); N] {
147 type Response = WebResponse;
148 type Error = Error;
149
150 async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
151 let res = ctx.into_response(ResponseBody::empty());
152 Responder::<WebContext<'r, C, B>>::map(self, res)
153 }
154
155 fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
156 for (k, v) in self {
157 res.headers_mut().append(k, v);
158 }
159 Ok(res)
160 }
161}
162
163impl<'r, C, B> Responder<WebContext<'r, C, B>> for Vec<(HeaderName, HeaderValue)> {
164 type Response = WebResponse;
165 type Error = Error;
166
167 async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
168 let res = ctx.into_response(ResponseBody::empty());
169 Responder::<WebContext<'r, C, B>>::map(self, res)
170 }
171
172 fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
173 for (k, v) in self {
174 res.headers_mut().append(k, v);
175 }
176 Ok(res)
177 }
178}
179
180impl<'r, C, B> Responder<WebContext<'r, C, B>> for HeaderMap {
181 type Response = WebResponse;
182 type Error = Error;
183
184 async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
185 let res = ctx.into_response(ResponseBody::empty());
186 Responder::<WebContext<'r, C, B>>::map(self, res)
187 }
188
189 fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
190 res.headers_mut().extend(self);
191 Ok(res)
192 }
193}
194
#[cfg(test)]
mod test {
    use xitca_unsafe_collection::futures::NowOrPanic;

    use super::*;

    /// Headers present on the request are extracted by their const index.
    #[test]
    fn extract_header() {
        let mut req = WebContext::new_test(());
        let mut req = req.as_web_ctx();
        req.req_mut()
            .headers_mut()
            .insert(header::HOST, header::HeaderValue::from_static("996"));
        req.req_mut()
            .headers_mut()
            .insert(header::ACCEPT_ENCODING, header::HeaderValue::from_static("251"));

        assert_eq!(
            HeaderRef::<'_, { super::ACCEPT_ENCODING }>::from_request(&req)
                .now_or_panic()
                .unwrap()
                .deref(),
            &header::HeaderValue::from_static("251")
        );
        assert_eq!(
            HeaderRef::<'_, { super::HOST }>::from_request(&req)
                .now_or_panic()
                .unwrap()
                .deref(),
            &header::HeaderValue::from_static("996")
        );
    }

    /// A header the request does not carry yields an error instead of a value.
    #[test]
    fn extract_missing_header() {
        let mut req = WebContext::new_test(());
        let req = req.as_web_ctx();

        // No `Accept` header was inserted, so extraction must take the `HeaderNotFound` path.
        assert!(
            HeaderRef::<'_, { super::ACCEPT }>::from_request(&req)
                .now_or_panic()
                .is_err()
        );
    }
}