1use lambda_http::{
2 http::{
3 header::{HeaderValue, IntoHeaderName},
4 StatusCode,
5 },
6 Body,
7};
8
9use crate::{error, handler::Handler};
10
11pub type Middleware<E = error::Error> = Box<dyn Fn(Handler<E>) -> Handler<E>>;
13
14pub fn header<E: 'static, N, V>(header_name: N, header_value: V) -> Middleware<E>
16where
17 N: IntoHeaderName + Copy + 'static,
18 V: Into<String> + Copy + 'static,
19{
20 Box::new(move |handler| {
21 Box::new(move |request, context| {
22 let mut response = handler(request, context)?;
23 if let Ok(value) = HeaderValue::from_str(header_value.into().as_str()) {
24 response.headers_mut().insert(header_name, value);
25 }
26 Ok(response)
27 })
28 })
29}
30
31pub fn status<E: 'static>(code: StatusCode) -> Middleware<E> {
33 Box::new(move |handler| {
34 Box::new(move |request, context| {
35 let mut response = handler(request, context)?;
36 *response.status_mut() = code;
37 Ok(response)
38 })
39 })
40}
41
42pub fn body<E: 'static, B>(body: B) -> Middleware<E>
44where
45 B: Into<Body> + Copy + 'static,
46{
47 Box::new(move |handler| {
48 Box::new(move |request, context| {
49 let mut response = handler(request, context)?;
50 *response.body_mut() = body.into();
51 Ok(response)
52 })
53 })
54}
55
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use lambda_http::{http::StatusCode, Body, Request, Response};
    use lambda_runtime::Context;

    use crate::handler::{default_handler, WrappingHandler};

    use super::*;

    /// Runs `handler` with a default request and context and unwraps the
    /// resulting response.
    fn handler_resp<E: Debug>(handler: Handler<E>) -> Response<Body> {
        let request = Request::default();
        let context = Context::default();
        handler(request, context).unwrap()
    }

    #[test]
    fn test_header() {
        let handler = default_handler::<error::Error>()
            .wrap_with(header("x-foo", "bar"))
            .handler();
        let resp = handler_resp(handler);
        assert_eq!(
            resp.headers().get("x-foo"),
            Some(&HeaderValue::from_static("bar"))
        );
    }

    #[test]
    fn test_header_invalid_value_is_skipped() {
        // A newline is not a legal header-value byte. The middleware is
        // best-effort, so the request must still succeed without the header.
        let handler = default_handler::<error::Error>()
            .wrap_with(header("x-foo", "bad\nvalue"))
            .handler();
        let resp = handler_resp(handler);
        assert!(resp.headers().get("x-foo").is_none());
    }

    #[test]
    fn test_status() {
        let handler = default_handler::<error::Error>()
            .wrap_with(status(StatusCode::CREATED))
            .handler();
        let resp = handler_resp(handler);
        assert_eq!(resp.status(), StatusCode::CREATED);
    }

    #[test]
    fn test_body() {
        let handler = default_handler::<error::Error>()
            .wrap_with(body("foo"))
            .handler();
        let resp = handler_resp(handler);
        assert_eq!(*resp.body(), Body::Text("foo".to_string()));
    }
}