// tower_embed_core/service.rs
use std::{
    convert::Infallible,
    marker::PhantomData,
    sync::Arc,
    task::{Context, Poll},
};

use crate::{Body, ResponseFuture};

12pub trait Embed {
14 fn forward(
16 req: http::Request<()>,
17 ) -> impl Future<Output = http::Response<Body>> + Send + 'static;
18}
19
20pub trait EmbedExt: Embed + Sized {
22 fn not_found_page(path: &str) -> NotFoundPage<Self> {
24 NotFoundPage::new(path)
25 }
26}
27
28impl<T> EmbedExt for T where T: Embed + Sized {}
29
30type NotFoundService = tower::util::BoxCloneSyncService<(), http::Response<Body>, Infallible>;
31
32pub struct ServeEmbed<E = ()> {
34 _embed: PhantomData<E>,
35 not_found_service: Option<NotFoundService>,
37}
38
39impl<E> Clone for ServeEmbed<E> {
40 fn clone(&self) -> Self {
41 Self {
42 _embed: PhantomData,
43 not_found_service: self.not_found_service.clone(),
44 }
45 }
46}
47
48impl<E: Embed> Default for ServeEmbed<E> {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl<E: Embed> ServeEmbed<E> {
55 pub fn new() -> Self {
57 Self {
58 _embed: PhantomData,
59 not_found_service: None,
60 }
61 }
62
63 pub fn with_not_found<S>(mut self, service: S) -> Self
65 where
66 S: tower::Service<(), Response = http::Response<Body>, Error = Infallible>
67 + Send
68 + Sync
69 + Clone
70 + 'static,
71 S::Future: Send + 'static,
72 {
73 self.not_found_service = Some(tower::util::BoxCloneSyncService::new(service));
74 self
75 }
76}
77
78impl<E, ReqBody> tower::Service<http::Request<ReqBody>> for ServeEmbed<E>
79where
80 E: Embed + Send + 'static,
81{
82 type Response = http::Response<Body>;
83 type Error = std::convert::Infallible;
84 type Future = ResponseFuture;
85
86 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 Poll::Ready(Ok(()))
88 }
89
90 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
91 let req = req.map(|_| ());
92 let mut not_found_service = self.not_found_service.clone();
93
94 ResponseFuture::new(async move {
95 use tower::ServiceExt;
96
97 let response =
98 if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
99 crate::response::method_not_allowed()
100 } else {
101 let mut response = E::forward(req).await;
102 if let Some(not_found_service) = not_found_service.take()
103 && response.status() == http::StatusCode::NOT_FOUND
104 {
105 let service = not_found_service.ready_oneshot().await.unwrap();
106 response = service.oneshot(()).await.unwrap()
107 }
108
109 response
110 };
111 Ok(response)
112 })
113 }
114}
115
116pub struct NotFoundPage<E>(Arc<NotFoundPageInner<E>>);
118
119impl<E> Clone for NotFoundPage<E> {
120 fn clone(&self) -> Self {
121 Self(Arc::clone(&self.0))
122 }
123}
124
125struct NotFoundPageInner<E> {
126 _embed: PhantomData<E>,
127 page: String,
128}
129
130impl<E> NotFoundPage<E> {
131 pub(crate) fn new(page: &str) -> Self {
132 let page = if page.starts_with('/') {
133 page.to_string()
134 } else {
135 format!("/{}", page)
136 };
137
138 Self(Arc::new(NotFoundPageInner {
139 _embed: PhantomData,
140 page,
141 }))
142 }
143}
144
145impl<E> tower::Service<()> for NotFoundPage<E>
146where
147 E: Embed,
148{
149 type Response = http::Response<Body>;
150 type Error = Infallible;
151 type Future = ResponseFuture;
152
153 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154 Poll::Ready(Ok(()))
155 }
156
157 fn call(&mut self, _: ()) -> Self::Future {
158 let req = http::Request::builder()
159 .method(http::Method::GET)
160 .uri(&self.0.page)
161 .body(())
162 .unwrap();
163 ResponseFuture::new(async move {
164 let mut response = E::forward(req).await;
165 response.headers_mut().remove(http::header::ETAG);
166 response.headers_mut().remove(http::header::LAST_MODIFIED);
167
168 Ok(response)
169 })
170 }
171}