Skip to main content

tower_embed_core/
service.rs

//! Tower services for serving embedded assets.
2
3use std::{
4    convert::Infallible,
5    marker::PhantomData,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use crate::{Body, ResponseFuture};
11
12/// The trait used to access to embedded assets.
13pub trait Embed {
14    /// Forward an HTTP request to the embedded asset service.
15    fn forward(
16        req: http::Request<()>,
17    ) -> impl Future<Output = http::Response<Body>> + Send + 'static;
18}
19
20/// Extension trait for [`Embed`].
21pub trait EmbedExt: Embed + Sized {
22    /// Returns a service that serves a custom not found page.
23    fn not_found_page(path: &str) -> NotFoundPage<Self> {
24        NotFoundPage::new(path)
25    }
26}
27
28impl<T> EmbedExt for T where T: Embed + Sized {}
29
30type NotFoundService = tower::util::BoxCloneSyncService<(), http::Response<Body>, Infallible>;
31
32/// Service that serves files from embedded assets.
33pub struct ServeEmbed<E = ()> {
34    _embed: PhantomData<E>,
35    /// Fallback service for handling 404 Not Found errors.
36    not_found_service: Option<NotFoundService>,
37}
38
39impl<E> Clone for ServeEmbed<E> {
40    fn clone(&self) -> Self {
41        Self {
42            _embed: PhantomData,
43            not_found_service: self.not_found_service.clone(),
44        }
45    }
46}
47
48impl<E: Embed> Default for ServeEmbed<E> {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl<E: Embed> ServeEmbed<E> {
55    /// Create a new [`ServeEmbed`] service.
56    pub fn new() -> Self {
57        Self {
58            _embed: PhantomData,
59            not_found_service: None,
60        }
61    }
62
63    /// Set the fallback service for not found pages.
64    pub fn with_not_found<S>(mut self, service: S) -> Self
65    where
66        S: tower::Service<(), Response = http::Response<Body>, Error = Infallible>
67            + Send
68            + Sync
69            + Clone
70            + 'static,
71        S::Future: Send + 'static,
72    {
73        self.not_found_service = Some(tower::util::BoxCloneSyncService::new(service));
74        self
75    }
76}
77
78impl<E, ReqBody> tower::Service<http::Request<ReqBody>> for ServeEmbed<E>
79where
80    E: Embed + Send + 'static,
81{
82    type Response = http::Response<Body>;
83    type Error = std::convert::Infallible;
84    type Future = ResponseFuture;
85
86    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        Poll::Ready(Ok(()))
88    }
89
90    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
91        let req = req.map(|_| ());
92        let mut not_found_service = self.not_found_service.clone();
93
94        ResponseFuture::new(async move {
95            use tower::ServiceExt;
96
97            let response =
98                if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
99                    crate::response::method_not_allowed()
100                } else {
101                    let mut response = E::forward(req).await;
102                    if let Some(not_found_service) = not_found_service.take()
103                        && response.status() == http::StatusCode::NOT_FOUND
104                    {
105                        let service = not_found_service.ready_oneshot().await.unwrap();
106                        response = service.oneshot(()).await.unwrap()
107                    }
108
109                    response
110                };
111            Ok(response)
112        })
113    }
114}
115
116/// A service that serves a custom not found page.
117pub struct NotFoundPage<E>(Arc<NotFoundPageInner<E>>);
118
119impl<E> Clone for NotFoundPage<E> {
120    fn clone(&self) -> Self {
121        Self(Arc::clone(&self.0))
122    }
123}
124
125struct NotFoundPageInner<E> {
126    _embed: PhantomData<E>,
127    page: String,
128}
129
130impl<E> NotFoundPage<E> {
131    pub(crate) fn new(page: &str) -> Self {
132        let page = if page.starts_with('/') {
133            page.to_string()
134        } else {
135            format!("/{}", page)
136        };
137
138        Self(Arc::new(NotFoundPageInner {
139            _embed: PhantomData,
140            page,
141        }))
142    }
143}
144
145impl<E> tower::Service<()> for NotFoundPage<E>
146where
147    E: Embed,
148{
149    type Response = http::Response<Body>;
150    type Error = Infallible;
151    type Future = ResponseFuture;
152
153    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        Poll::Ready(Ok(()))
155    }
156
157    fn call(&mut self, _: ()) -> Self::Future {
158        let req = http::Request::builder()
159            .method(http::Method::GET)
160            .uri(&self.0.page)
161            .body(())
162            .unwrap();
163        ResponseFuture::new(async move {
164            let mut response = E::forward(req).await;
165            response.headers_mut().remove(http::header::ETAG);
166            response.headers_mut().remove(http::header::LAST_MODIFIED);
167
168            Ok(response)
169        })
170    }
171}