Skip to main content

static_serve/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::convert::Infallible;
4
5use axum::{
6    Router,
7    extract::FromRequestParts,
8    http::{
9        StatusCode,
10        header::{
11            ACCEPT_ENCODING, ACCEPT_RANGES, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_TYPE, ETAG,
12            HeaderValue, IF_NONE_MATCH, VARY,
13        },
14        request::Parts,
15    },
16    response::IntoResponse,
17    routing::{MethodRouter, get},
18};
19use bytes::Bytes;
20use range_requests::{
21    headers::{if_range::IfRange, range::HttpRange},
22    serve_file_with_http_range,
23};
24
25pub use static_serve_macro::{embed_asset, embed_assets};
26
/// Which pre-compressed encodings (gzip, zstd) the client is willing to receive.
#[derive(Clone, Copy, Debug)]
struct AcceptEncoding {
    /// `true` when the client accepts a gzip-encoded body.
    pub gzip: bool,
    /// `true` when the client accepts a zstd-encoded body.
    pub zstd: bool,
}
35
36impl<S> FromRequestParts<S> for AcceptEncoding
37where
38    S: Send + Sync,
39{
40    type Rejection = Infallible;
41
42    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
43        let accept_encoding = parts.headers.get(ACCEPT_ENCODING);
44        let accept_encoding = accept_encoding
45            .and_then(|accept_encoding| accept_encoding.to_str().ok())
46            .unwrap_or_default();
47
48        Ok(Self {
49            gzip: accept_encoding.contains("gzip"),
50            zstd: accept_encoding.contains("zstd"),
51        })
52    }
53}
54
55/// Check if the  `IfNoneMatch` header is present
56#[derive(Debug)]
57struct IfNoneMatch(Option<HeaderValue>);
58
59impl IfNoneMatch {
60    /// required function for checking if `IfNoneMatch` is present
61    fn matches(&self, etag: &str) -> bool {
62        self.0
63            .as_ref()
64            .is_some_and(|if_none_match| if_none_match.as_bytes() == etag.as_bytes())
65    }
66}
67
68impl<S> FromRequestParts<S> for IfNoneMatch
69where
70    S: Send + Sync,
71{
72    type Rejection = Infallible;
73
74    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
75        let if_none_match = parts.headers.get(IF_NONE_MATCH).cloned();
76        Ok(Self(if_none_match))
77    }
78}
79
80#[doc(hidden)]
81#[expect(clippy::too_many_arguments)]
82/// The router for adding routes for static assets
83pub fn static_route<S>(
84    router: Router<S>,
85    web_path: &'static str,
86    content_type: &'static str,
87    etag: &'static str,
88    body: &'static [u8],
89    body_gz: Option<&'static [u8]>,
90    body_zst: Option<&'static [u8]>,
91    cache_busted: bool,
92) -> Router<S>
93where
94    S: Clone + Send + Sync + 'static,
95{
96    router.route(
97        web_path,
98        get(
99            move |accept_encoding: AcceptEncoding,
100                  if_none_match: IfNoneMatch,
101                  http_range: Option<HttpRange>,
102                  if_range: Option<IfRange>| async move {
103                static_inner(StaticInnerData {
104                    content_type,
105                    etag,
106                    body,
107                    body_gz,
108                    body_zst,
109                    cache_busted,
110                    accept_encoding,
111                    if_none_match,
112                    http_range,
113                    if_range,
114                })
115            },
116        ),
117    )
118}
119
120#[doc(hidden)]
121/// Creates a route for a single static asset.
122///
123/// Used by the `embed_asset!` macro, so it needs to be `pub`.
124pub fn static_method_router<S>(
125    content_type: &'static str,
126    etag: &'static str,
127    body: &'static [u8],
128    body_gz: Option<&'static [u8]>,
129    body_zst: Option<&'static [u8]>,
130    cache_busted: bool,
131) -> MethodRouter<S>
132where
133    S: Clone + Send + Sync + 'static,
134{
135    MethodRouter::get(
136        MethodRouter::new(),
137        move |accept_encoding: AcceptEncoding,
138              if_none_match: IfNoneMatch,
139              http_range: Option<HttpRange>,
140              if_range: Option<IfRange>| async move {
141            static_inner(StaticInnerData {
142                content_type,
143                etag,
144                body,
145                body_gz,
146                body_zst,
147                cache_busted,
148                accept_encoding,
149                if_none_match,
150                http_range,
151                if_range,
152            })
153        },
154    )
155}
156
157/// Struct of parameters for `static_inner` (to avoid `clippy::too_many_arguments`)
158///
159/// This differs from `StaticRouteData` because it
160/// includes the `AcceptEncoding` and `IfNoneMatch` fields
161/// and excludes the `web_path`
162struct StaticInnerData {
163    content_type: &'static str,
164    etag: &'static str,
165    body: &'static [u8],
166    body_gz: Option<&'static [u8]>,
167    body_zst: Option<&'static [u8]>,
168    cache_busted: bool,
169    accept_encoding: AcceptEncoding,
170    if_none_match: IfNoneMatch,
171    http_range: Option<HttpRange>,
172    if_range: Option<IfRange>,
173}
174
175fn static_inner(static_inner_data: StaticInnerData) -> impl IntoResponse {
176    let StaticInnerData {
177        content_type,
178        etag,
179        body,
180        body_gz,
181        body_zst,
182        cache_busted,
183        accept_encoding,
184        if_none_match,
185        http_range,
186        if_range,
187    } = static_inner_data;
188
189    let optional_cache_control = if cache_busted {
190        Some([(
191            CACHE_CONTROL,
192            HeaderValue::from_static("public, max-age=31536000, immutable"),
193        )])
194    } else {
195        None
196    };
197
198    let resp_base = (
199        [
200            (CONTENT_TYPE, HeaderValue::from_static(content_type)),
201            (ETAG, HeaderValue::from_static(etag)),
202            (VARY, HeaderValue::from_static("Accept-Encoding")),
203        ],
204        optional_cache_control,
205    );
206
207    if if_none_match.matches(etag) {
208        return (resp_base, StatusCode::NOT_MODIFIED).into_response();
209    }
210
211    let resp_base = (
212        [(ACCEPT_RANGES, HeaderValue::from_static("bytes"))],
213        resp_base,
214    );
215
216    let http_range = match (http_range, if_range) {
217        (Some(range), Some(if_range)) => {
218            let etag_value = HeaderValue::from_static(etag);
219            if_range.evaluate(range, None, Some(&etag_value))
220        }
221        (range, _) => range,
222    };
223
224    let (selected_body, optional_content_encoding) = match (
225        (accept_encoding.gzip, body_gz),
226        (accept_encoding.zstd, body_zst),
227        &http_range,
228    ) {
229        (_, (true, Some(body_zst)), None) => (
230            Bytes::from_static(body_zst),
231            Some([(CONTENT_ENCODING, HeaderValue::from_static("zstd"))]),
232        ),
233        ((true, Some(body_gz)), _, None) => (
234            Bytes::from_static(body_gz),
235            Some([(CONTENT_ENCODING, HeaderValue::from_static("gzip"))]),
236        ),
237        _ => (Bytes::from_static(body), None),
238    };
239
240    match serve_file_with_http_range(selected_body, http_range) {
241        Ok(body_range) => (resp_base, optional_content_encoding, body_range).into_response(),
242        Err(unsatisfiable) => (resp_base, unsatisfiable).into_response(),
243    }
244}