1use std::panic::{AssertUnwindSafe, RefUnwindSafe, UnwindSafe};
28
29use http::{uri::PathAndQuery, Uri};
30
31use crate::{request::LogLevel, unit::UnitService, UnitError, UnitResult};
32
33pub use http::{Request, Response};
34
/// A request handler expressed in terms of the `http` crate's [`Request`] and
/// [`Response`] types, with the whole body held in memory as a `Vec<u8>`.
///
/// Implementors must be [`UnwindSafe`]. When a service is driven through
/// [`HttpHandler`], the call is wrapped in `std::panic::catch_unwind`, so a
/// panic inside the handler is logged before unwinding resumes.
pub trait HttpService: UnwindSafe {
    /// Handles a single request and produces the response to send back.
    ///
    /// # Errors
    ///
    /// Any error may be returned boxed. When called via [`HttpHandler`], the
    /// error's `Display` text is sent to the client in a `text/plain` error
    /// response instead of being propagated further.
    fn handle_request(
        &self,
        _req: Request<Vec<u8>>,
    ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error>>;
}
47
/// Adapter that wraps an [`HttpService`] so that it can be used wherever a
/// [`UnitService`] is expected.
///
/// Build one with [`HttpHandler::new`].
pub struct HttpHandler<H: HttpService>(H);
54
55impl<H: HttpService> HttpHandler<H> {
56 pub fn new(unit_service: H) -> Self {
57 Self(unit_service)
58 }
59}
60
61impl<H: HttpService + RefUnwindSafe> UnitService for HttpHandler<H> {
62 fn handle_request(&mut self, mut req: crate::request::Request<'_>) -> UnitResult<()> {
63 self.handle_request_with_http(&mut req).map_err(|err| {
64 req.log(LogLevel::Error, err.to_string());
65 UnitError::error()
66 })
67 }
68}
69
70impl<H: HttpService + RefUnwindSafe> HttpHandler<H> {
71 fn handle_request_with_http(
72 &self,
73 req: &mut crate::request::Request<'_>,
74 ) -> Result<(), Box<dyn std::error::Error>> {
75 let path_and_query: PathAndQuery = req.target().parse()?;
76 let uri = Uri::builder()
77 .scheme(if req.tls() { "https" } else { "http" })
78 .authority(req.server_name())
79 .path_and_query(path_and_query)
80 .build()?;
81 let mut http_request_builder = Request::builder();
82
83 for (name, value) in req.fields() {
84 http_request_builder = http_request_builder.header(name, value);
85 }
86
87 let http_request = http_request_builder
88 .uri(uri)
89 .method(req.method())
90 .body(req.body().read_to_vec()?)?;
91
92 let http_request = AssertUnwindSafe(http_request);
96 let handler = &self.0;
97
98 let http_response = std::panic::catch_unwind(move || {
99 let http_request = http_request;
100 handler.handle_request(http_request.0)
101 });
102
103 match http_response {
104 Ok(Ok(http_response)) => {
105 let header_count = http_response.headers().len();
106 let headers_size: usize = http_response
107 .headers()
108 .iter()
109 .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
110 .sum();
111 let body_size = http_response.body().len();
112
113 let response = req.create_response(
114 http_response.status().as_u16(),
115 header_count,
116 headers_size + body_size,
117 )?;
118
119 for (name, value) in http_response.headers() {
120 response.add_field(name, value)?;
121 }
122 response.add_content(http_response.body())?;
123 response.send()?;
124 }
125 Ok(Err(err)) => {
126 let content_type = ("Content-Type", "text/plain");
127 let response_body = format!("The server experienced an internal error: {}", err);
128 let response = req.create_response(
129 501,
130 1,
131 content_type.0.len() + content_type.1.len() + response_body.len(),
132 )?;
133 response.add_field(content_type.0, content_type.1)?;
134 response.add_content(response_body)?;
135 response.send()?;
136 }
137 Err(panic_payload) => {
138 req.log(LogLevel::Error, "Panicked during http request handling.");
139
140 std::panic::resume_unwind(panic_payload);
143 }
144 }
145 Ok(())
146 }
147}
148
149impl<F> HttpService for F
150where
151 F: Fn(Request<Vec<u8>>) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error>>,
152 F: UnwindSafe + 'static,
153{
154 fn handle_request(
155 &self,
156 req: Request<Vec<u8>>,
157 ) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error>> {
158 self(req)
159 }
160}