1use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
55use base64::Engine as _;
56use http::{
57 header::{self, HeaderValue},
58 Request, Response, StatusCode,
59};
60use http_body::Body;
61use std::{fmt, marker::PhantomData};
62
63const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
64
65impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
66 pub fn basic(inner: S, username: &str, value: &str) -> Self
74 where
75 ResBody: Body + Default,
76 {
77 Self::custom(inner, Basic::new(username, value))
78 }
79}
80
81impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
82 pub fn basic(username: &str, password: &str) -> Self
90 where
91 ResBody: Body + Default,
92 {
93 Self::custom(Basic::new(username, password))
94 }
95}
96
97impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
98 pub fn bearer(inner: S, token: &str) -> Self
106 where
107 ResBody: Body + Default,
108 {
109 Self::custom(inner, Bearer::new(token))
110 }
111}
112
113impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
114 pub fn bearer(token: &str) -> Self
122 where
123 ResBody: Body + Default,
124 {
125 Self::custom(Bearer::new(token))
126 }
127}
128
129pub struct Bearer<ResBody> {
133 header_value: HeaderValue,
134 _ty: PhantomData<fn() -> ResBody>,
135}
136
137impl<ResBody> Bearer<ResBody> {
138 fn new(token: &str) -> Self
139 where
140 ResBody: Body + Default,
141 {
142 Self {
143 header_value: format!("Bearer {}", token)
144 .parse()
145 .expect("token is not a valid header value"),
146 _ty: PhantomData,
147 }
148 }
149}
150
151impl<ResBody> Clone for Bearer<ResBody> {
152 fn clone(&self) -> Self {
153 Self {
154 header_value: self.header_value.clone(),
155 _ty: PhantomData,
156 }
157 }
158}
159
160impl<ResBody> fmt::Debug for Bearer<ResBody> {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 f.debug_struct("Bearer")
163 .field("header_value", &self.header_value)
164 .finish()
165 }
166}
167
168impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
169where
170 ResBody: Body + Default,
171{
172 type ResponseBody = ResBody;
173
174 fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
175 match request.headers().get(header::AUTHORIZATION) {
176 Some(actual) if actual == self.header_value => Ok(()),
177 _ => {
178 let mut res = Response::new(ResBody::default());
179 *res.status_mut() = StatusCode::UNAUTHORIZED;
180 Err(res)
181 }
182 }
183 }
184}
185
186pub struct Basic<ResBody> {
190 header_value: HeaderValue,
191 _ty: PhantomData<fn() -> ResBody>,
192}
193
194impl<ResBody> Basic<ResBody> {
195 fn new(username: &str, password: &str) -> Self
196 where
197 ResBody: Body + Default,
198 {
199 let encoded = BASE64.encode(format!("{}:{}", username, password));
200 let header_value = format!("Basic {}", encoded).parse().unwrap();
201 Self {
202 header_value,
203 _ty: PhantomData,
204 }
205 }
206}
207
208impl<ResBody> Clone for Basic<ResBody> {
209 fn clone(&self) -> Self {
210 Self {
211 header_value: self.header_value.clone(),
212 _ty: PhantomData,
213 }
214 }
215}
216
217impl<ResBody> fmt::Debug for Basic<ResBody> {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 f.debug_struct("Basic")
220 .field("header_value", &self.header_value)
221 .finish()
222 }
223}
224
225impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
226where
227 ResBody: Body + Default,
228{
229 type ResponseBody = ResBody;
230
231 fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
232 match request.headers().get(header::AUTHORIZATION) {
233 Some(actual) if actual == self.header_value => Ok(()),
234 _ => {
235 let mut res = Response::new(ResBody::default());
236 *res.status_mut() = StatusCode::UNAUTHORIZED;
237 res.headers_mut()
238 .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
239 Err(res)
240 }
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    // Explicit import shadows the `http_body::Body` trait brought in by the glob above,
    // so `Body` below is the crate's test body type.
    use crate::test_helpers::Body;

    use http::header;
    use tower_async::{BoxError, ServiceBuilder};
    use tower_async_service::Service;

    /// Builds a `GET /` request whose `Authorization` header is `value`.
    fn request_with_authorization(value: &str) -> Request<Body> {
        Request::get("/")
            .header(header::AUTHORIZATION, value)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a `GET /` request with no `Authorization` header at all.
    fn request_without_authorization() -> Request<Body> {
        Request::get("/").body(Body::empty()).unwrap()
    }

    /// Formats `user_and_password` (`user:password`) as a `Basic` credential value.
    fn basic_credentials(user_and_password: &str) -> String {
        format!("Basic {}", BASE64.encode(user_and_password))
    }

    #[tokio::test]
    async fn valid_basic_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = request_with_authorization(&basic_credentials("foo:bar"));
        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = request_with_authorization(&basic_credentials("wrong:credentials"));
        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn missing_basic_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let res = service.call(request_without_authorization()).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let res = service
            .call(request_with_authorization("Bearer foobar"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        // Lower-case scheme name with otherwise correct credentials.
        let lowercase_scheme = format!("basic {}", BASE64.encode("foo:bar"));
        let res = service
            .call(request_with_authorization(&lowercase_scheme))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = request_with_authorization(&basic_credentials("Foo:bar"));
        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let res = service
            .call(request_with_authorization("Bearer wat"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn missing_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let res = service.call(request_without_authorization()).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let res = service
            .call(request_with_authorization("bearer foobar"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let res = service
            .call(request_with_authorization("Bearer Foobar"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that answers every request with its own body.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}