1use std::{
2 borrow::Cow,
3 pin::Pin,
4 task::{Context, Poll, ready},
5};
6
7use bytes::Bytes;
8use futures_core::{Stream, future::BoxFuture};
9use http_body::{Body, Frame};
10use http_body_util::BodyExt;
11use tower_embed_core::{BoxError, Embed, Embedded, headers};
12
13use crate::core::headers::HeaderMapExt;
14
/// Type-erased HTTP body that yields [`Bytes`] data and reports failures as [`BoxError`].
///
/// This is the storage used inside [`ResponseBody`]; every concrete body is converted into it.
type BoxBody = http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>;
16
/// Body type of the responses produced by this crate.
///
/// A thin newtype over a boxed [`Body`], so that empty, in-memory and streamed payloads all
/// share one concrete type.
#[derive(Debug)]
pub struct ResponseBody(BoxBody);
20
21impl ResponseBody {
22 pub fn empty() -> Self {
24 ResponseBody::new(http_body_util::Empty::new())
25 }
26
27 pub fn full(data: Bytes) -> Self {
29 ResponseBody::new(http_body_util::Full::new(data))
30 }
31
32 pub fn stream<S, E>(stream: S) -> Self
34 where
35 S: Stream<Item = Result<Frame<Bytes>, E>> + Send + 'static,
36 E: Into<BoxError>,
37 {
38 ResponseBody::new(http_body_util::StreamBody::new(stream))
39 }
40
41 fn new<B>(body: B) -> Self
42 where
43 B: Body<Data = Bytes> + Send + 'static,
44 B::Error: Into<BoxError>,
45 {
46 ResponseBody(body.map_err(|err| err.into()).boxed_unsync())
47 }
48}
49
50impl http_body::Body for ResponseBody {
51 type Data = Bytes;
52 type Error = BoxError;
53
54 fn poll_frame(
55 mut self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
58 Pin::new(&mut self.0).poll_frame(cx)
59 }
60
61 fn is_end_stream(&self) -> bool {
62 self.0.is_end_stream()
63 }
64
65 fn size_hint(&self) -> http_body::SizeHint {
66 self.0.size_hint()
67 }
68}
69
/// Future that resolves to the HTTP response for a single request.
///
/// The response is either already available (for example for a disallowed method) or is built
/// once the lookup of the embedded asset has finished.
pub struct ResponseFuture {
    // Current state of the future; see `ResponseFutureInner`.
    inner: ResponseFutureInner,
}
76
/// Internal state of a [`ResponseFuture`].
enum ResponseFutureInner {
    /// A response that needs no further work. The option is `None` once the response has
    /// been handed out, so polling again can be detected.
    Ready(Option<http::Response<ResponseBody>>),
    /// The embedded asset is still being looked up.
    WaitingEmbedded {
        /// Pending lookup of the embedded asset.
        fut: BoxFuture<'static, std::io::Result<Embedded>>,
        /// `If-None-Match` header captured from the request, if it was sent.
        if_none_match: Option<headers::IfNoneMatch>,
        /// `If-Modified-Since` header captured from the request, if it was sent.
        if_modified_since: Option<headers::IfModifiedSince>,
    },
}
85
86impl ResponseFuture {
87 pub(crate) fn new<E, B>(req: &http::Request<B>) -> Self
88 where
89 E: Embed,
90 {
91 if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
92 return Self::method_not_allowed();
93 }
94
95 let path = get_file_path_from_uri(req.uri());
96 let embedded = E::get(path.as_ref());
97
98 let if_none_match = req.headers().typed_get::<headers::IfNoneMatch>();
99 let if_modified_since = req.headers().typed_get::<headers::IfModifiedSince>();
100
101 let inner = ResponseFutureInner::WaitingEmbedded {
102 fut: Box::pin(embedded),
103 if_none_match,
104 if_modified_since,
105 };
106 Self { inner }
107 }
108
109 pub(crate) fn method_not_allowed() -> Self {
110 let response = http::Response::builder()
111 .header(
112 http::header::ALLOW,
113 http::HeaderValue::from_static("GET, HEAD"),
114 )
115 .status(http::StatusCode::METHOD_NOT_ALLOWED)
116 .body(ResponseBody::empty())
117 .unwrap();
118
119 Self {
120 inner: ResponseFutureInner::Ready(Some(response)),
121 }
122 }
123}
124
125impl Future for ResponseFuture {
126 type Output = Result<http::Response<ResponseBody>, std::convert::Infallible>;
127
128 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129 let inner = &mut self.get_mut().inner;
130
131 let response = match inner {
132 ResponseFutureInner::Ready(response) => response
133 .take()
134 .expect("ResponseFuture polled after completion"),
135 ResponseFutureInner::WaitingEmbedded {
136 fut,
137 if_none_match,
138 if_modified_since,
139 } => match ready!(Pin::new(fut).poll(cx)) {
140 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
141 *inner = ResponseFutureInner::Ready(None);
142 http::Response::builder()
143 .status(http::StatusCode::NOT_FOUND)
144 .body(ResponseBody::empty())
145 .unwrap()
146 }
147 Err(_) => {
148 *inner = ResponseFutureInner::Ready(None);
149 http::Response::builder()
150 .status(http::StatusCode::INTERNAL_SERVER_ERROR)
151 .body(ResponseBody::empty())
152 .unwrap()
153 }
154 Ok(embedded) => {
155 if let Some(if_none_match) = if_none_match
157 && let Some(etag) = &embedded.metadata.etag
158 && !if_none_match.condition_passes(etag)
159 {
160 return Poll::Ready(Ok(http::Response::builder()
161 .status(http::StatusCode::NOT_MODIFIED)
162 .body(ResponseBody::empty())
163 .unwrap()));
164 }
165
166 if let Some(if_modified_since) = if_modified_since
168 && let Some(last_modified) = embedded.metadata.last_modified
169 && !if_modified_since.condition_passes(&last_modified)
170 {
171 return Poll::Ready(Ok(http::Response::builder()
172 .status(http::StatusCode::NOT_MODIFIED)
173 .body(ResponseBody::empty())
174 .unwrap()));
175 }
176
177 let Embedded { content, metadata } = embedded;
178 let mut response = http::Response::builder()
179 .status(http::StatusCode::OK)
180 .body(ResponseBody::stream(content))
181 .unwrap();
182
183 response.headers_mut().typed_insert(metadata.content_type);
184 if let Some(etag) = metadata.etag {
185 response.headers_mut().typed_insert(etag);
186 }
187 if let Some(last_modified) = metadata.last_modified {
188 response.headers_mut().typed_insert(last_modified);
189 }
190
191 response
192 }
193 },
194 };
195 Poll::Ready(Ok(response))
196 }
197}
198
199fn get_file_path_from_uri(uri: &http::Uri) -> Cow<'_, str> {
200 let path = uri.path();
201 if path.ends_with("/") {
202 Cow::Owned(format!("{}index.html", path.trim_start_matches('/')))
203 } else {
204 Cow::Borrowed(path.trim_start_matches('/'))
205 }
206}