Skip to main content

tower_embed/
lib.rs

//! This crate provides a [`tower`] service that serves embedded static assets for web
//! applications. The service supports the following HTTP features:
3//!
4//! - Support for GET and HEAD requests
5//! - `Content-Type` header generation based on file MIME type guessed from extension.
6//! - `ETag` header generation and validation.
7//! - `Last-Modified` header generation and validation.
8//! - Customizable 404 page.
9//!
//! In `debug` mode, assets are served directly from the filesystem to facilitate rapid
//! development. Neither the `ETag` nor the `Last-Modified` header is generated in this mode.
12//!
13//! # Usage
14//!
15//! ```no_run
16//! use axum::Router;
17//! use tower_embed::{Embed, EmbedExt, ServeEmbed};
18//!
19//! #[derive(Embed)]
20//! #[embed(folder = "assets")]
21//! struct Assets;
22//!
23//! #[tokio::main]
24//! async fn main() {
25//!     let assets = ServeEmbed::builder()
26//!         .not_found_service(Assets::not_found_page("404.html"))
27//!         .build::<Assets>();
28//!     let router = Router::new().fallback_service(assets);
29//!
30//!     let listener = tokio::net::TcpListener::bind("127.0.0.1:8080")
31//!         .await
32//!         .unwrap();
33//!     axum::serve::serve(listener, router).await.unwrap();
34//! }
35//! ```
36//!
37//! Please see the [examples] directory for working examples.
38//!
39//! [`tower`]: https://crates.io/crates/tower
40//! [examples]: https://github.com/mattiapenati/tower-embed/tree/main/examples
41
// Fail the build early with an explicit message when the `tokio` feature is disabled:
// as the message states, tokio is the only supported runtime for this crate.
#[cfg(not(feature = "tokio"))]
compile_error!("Only tokio runtime is supported, and it is required to use `tower-embed`.");
44
45use std::{
46    convert::Infallible,
47    marker::PhantomData,
48    pin::Pin,
49    sync::Arc,
50    task::{Context, Poll},
51};
52
53#[doc(inline)]
54pub use tower_embed_impl::Embed;
55
56#[doc(inline)]
57pub use tower_embed_core as core;
58
59#[doc(inline)]
60pub use tower_embed_core::{Embed, http::Body};
61
62#[doc(hidden)]
63pub mod file;
64
65/// Response future of [`ServeEmbed`]
66pub struct ResponseFuture(ResponseFutureInner);
67
68type ResponseFutureInner =
69    Pin<Box<dyn Future<Output = Result<http::Response<Body>, Infallible>> + Send>>;
70
71impl ResponseFuture {
72    fn new<F>(future: F) -> Self
73    where
74        F: Future<Output = Result<http::Response<Body>, Infallible>> + Send + 'static,
75    {
76        ResponseFuture(Box::pin(future))
77    }
78}
79
80impl Future for ResponseFuture {
81    type Output = Result<http::Response<Body>, Infallible>;
82
83    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        self.0.as_mut().poll(cx)
85    }
86}
87
88/// Service that serves files from embedded assets.
89pub struct ServeEmbed<E = ()> {
90    _embed: PhantomData<E>,
91    /// Fallback service for handling 404 Not Found errors.
92    not_found_service: Option<NotFoundService>,
93}
94
95type NotFoundService =
96    tower::util::BoxCloneSyncService<http::Request<()>, http::Response<Body>, Infallible>;
97
98impl<E> Clone for ServeEmbed<E> {
99    fn clone(&self) -> Self {
100        Self {
101            _embed: PhantomData,
102            not_found_service: self.not_found_service.clone(),
103        }
104    }
105}
106
107impl<E: Embed> Default for ServeEmbed<E> {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl<E: Embed> ServeEmbed<E> {
114    /// Create a new [`ServeEmbed`] service.
115    pub fn new() -> Self {
116        ServeEmbedBuilder::new().build::<E>()
117    }
118}
119
120impl ServeEmbed<()> {
121    /// Create a new [`ServeEmbedBuilder`] to customize a new service instance.
122    pub fn builder() -> ServeEmbedBuilder {
123        ServeEmbedBuilder::new()
124    }
125}
126
127impl<E, ReqBody> tower::Service<http::Request<ReqBody>> for ServeEmbed<E>
128where
129    E: Embed + Send + 'static,
130{
131    type Response = http::Response<Body>;
132    type Error = std::convert::Infallible;
133    type Future = ResponseFuture;
134
135    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136        Poll::Ready(Ok(()))
137    }
138
139    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
140        let req = req.map(|_| ());
141        let not_found_service = self.not_found_service.clone();
142        ResponseFuture::new(async move {
143            let response =
144                if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
145                    method_not_allowed()
146                } else {
147                    let path = req.uri().path().trim_start_matches('/');
148                    tracing::trace!("Serving embedded resource '{path}'");
149                    handle_request(E::get(path), req, not_found_service).await
150                };
151            Ok(response)
152        })
153    }
154}
155
156/// Builder for [`ServeEmbed`] service.
157#[derive(Default)]
158pub struct ServeEmbedBuilder {
159    not_found_service: Option<NotFoundService>,
160}
161
162impl ServeEmbedBuilder {
163    /// Create a new [`ServeEmbedBuilder`].
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    /// Set the fallback service.
169    pub fn not_found_service<S>(mut self, service: S) -> Self
170    where
171        S: tower::Service<http::Request<()>, Response = http::Response<Body>, Error = Infallible>
172            + Send
173            + Sync
174            + Clone
175            + 'static,
176        S::Future: Send + 'static,
177    {
178        self.not_found_service = Some(tower::util::BoxCloneSyncService::new(service));
179        self
180    }
181
182    /// Build the [`ServeEmbed`] service.
183    pub fn build<E: Embed>(self) -> ServeEmbed<E> {
184        ServeEmbed {
185            _embed: PhantomData,
186            not_found_service: self.not_found_service,
187        }
188    }
189}
190
191/// Extension trait for [`Embed`].
192pub trait EmbedExt: Embed + Sized {
193    /// Returns a service that serves a custom not found page.
194    fn not_found_page(path: &str) -> NotFoundPage<Self> {
195        NotFoundPage::new(path.to_string())
196    }
197}
198
199impl<T> EmbedExt for T where T: Embed + Sized {}
200
/// A service that serves a custom not found page.
///
/// Cloning is cheap: all clones share the same page path through an [`Arc`].
pub struct NotFoundPage<E>(Arc<NotFoundPageInner<E>>);

// Written by hand because `#[derive(Clone)]` would add an `E: Clone` bound.
impl<E> Clone for NotFoundPage<E> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.0);
        NotFoundPage(shared)
    }
}

/// State shared by all clones of a [`NotFoundPage`].
struct NotFoundPageInner<E> {
    /// Embedded path of the page to serve.
    page: String,
    _embed: PhantomData<E>,
}

impl<E> NotFoundPage<E> {
    /// Wraps the page path into a new, shareable [`NotFoundPage`].
    fn new(page: String) -> Self {
        let inner = NotFoundPageInner {
            page,
            _embed: PhantomData,
        };
        Self(Arc::new(inner))
    }
}
223
224impl<E> tower::Service<http::Request<()>> for NotFoundPage<E>
225where
226    E: Embed,
227{
228    type Response = http::Response<Body>;
229    type Error = Infallible;
230    type Future = ResponseFuture;
231
232    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
233        Poll::Ready(Ok(()))
234    }
235
236    fn call(&mut self, req: http::Request<()>) -> Self::Future {
237        let embedded = E::get(&self.0.page);
238        ResponseFuture::new(async move { Ok(handle_request(embedded, req, None).await) })
239    }
240}
241
242async fn handle_request<F>(
243    embedded: F,
244    request: http::Request<()>,
245    not_found_service: Option<NotFoundService>,
246) -> http::Response<Body>
247where
248    F: Future<Output = std::io::Result<core::Embedded>> + Send,
249{
250    use core::headers::{self, HeaderMapExt};
251
252    let path = request.uri().path().trim_start_matches('/');
253    let core::Embedded { content, metadata } = match embedded.await {
254        Ok(embedded) => embedded,
255        Err(err)
256            if err.kind() == std::io::ErrorKind::NotFound
257                || err.kind() == std::io::ErrorKind::NotADirectory =>
258        {
259            tracing::trace!("Embedded resource not found: '{path}'");
260            return not_found_response(request, not_found_service).await;
261        }
262        Err(err) => {
263            tracing::error!("Failed to get embedded resource '{path}': {err}");
264            return server_error_response(err);
265        }
266    };
267
268    let if_none_match = request.headers().typed_get::<headers::IfNoneMatch>();
269    if let Some(if_none_match) = if_none_match
270        && let Some(etag) = &metadata.etag
271        && !if_none_match.condition_passes(etag)
272    {
273        tracing::trace!("ETag match for embedded resource '{path}'");
274        return not_modified_response();
275    }
276
277    let if_modified_since = request.headers().typed_get::<headers::IfModifiedSince>();
278    if let Some(if_modified_since) = if_modified_since
279        && let Some(last_modified) = &metadata.last_modified
280        && !if_modified_since.condition_passes(last_modified)
281    {
282        tracing::trace!("Last-Modified match for embedded resource '{path}'");
283        return not_modified_response();
284    }
285
286    let mut response = http::Response::builder()
287        .status(http::StatusCode::OK)
288        .body(Body::stream(content))
289        .unwrap();
290
291    response.headers_mut().typed_insert(metadata.content_type);
292    if let Some(etag) = metadata.etag {
293        response.headers_mut().typed_insert(etag);
294    }
295    if let Some(last_modified) = metadata.last_modified {
296        response.headers_mut().typed_insert(last_modified);
297    }
298
299    response
300}
301
302async fn not_found_response(
303    request: http::Request<()>,
304    mut not_found_service: Option<NotFoundService>,
305) -> http::Response<Body> {
306    use tower::ServiceExt;
307
308    let mut response = match not_found_service.take() {
309        Some(service) => {
310            let service = service.ready_oneshot().await.unwrap();
311            service.oneshot(request).await.unwrap()
312        }
313        None => http::Response::builder()
314            .status(http::StatusCode::NOT_FOUND)
315            .body(Body::empty())
316            .unwrap(),
317    };
318    response.headers_mut().insert(
319        http::header::CACHE_CONTROL,
320        http::HeaderValue::from_static("no-store"),
321    );
322    response
323}
324
325fn not_modified_response() -> http::Response<Body> {
326    http::Response::builder()
327        .status(http::StatusCode::NOT_MODIFIED)
328        .body(Body::empty())
329        .unwrap()
330}
331
332fn method_not_allowed() -> http::Response<Body> {
333    http::Response::builder()
334        .header(
335            http::header::ALLOW,
336            http::HeaderValue::from_static("GET, HEAD"),
337        )
338        .status(http::StatusCode::METHOD_NOT_ALLOWED)
339        .body(Body::empty())
340        .unwrap()
341}
342
343fn server_error_response(_err: std::io::Error) -> http::Response<Body> {
344    http::Response::builder()
345        .status(http::StatusCode::INTERNAL_SERVER_ERROR)
346        .header(http::header::CACHE_CONTROL, "no-store")
347        .body(Body::empty())
348        .unwrap()
349}