static_serve_macro/
lib.rs

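//! Proc macro crate for compressing and embedding static assets
//! in a web server.
//!
//! Macro invocation:
//! `embed_assets!("path/to/assets", compress = true, ignore_dirs = ["drafts"]);`
//! where `compress` defaults to `false` and `ignore_dirs` entries are
//! relative to the assets directory.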
4
5use std::{
6    convert::Into,
7    fs,
8    io::{self, Write},
9    path::{Path, PathBuf},
10};
11
12use display_full_error::DisplayFullError;
13use flate2::write::GzEncoder;
14use glob::glob;
15use proc_macro2::{Span, TokenStream};
16use quote::{quote, ToTokens};
17use sha1::{Digest as _, Sha1};
18use syn::{
19    bracketed,
20    parse::{Parse, ParseStream},
21    parse_macro_input, Ident, LitBool, LitByteStr, LitStr, Token,
22};
23
24mod error;
25use error::{Error, GzipType, ZstdType};
26
27#[proc_macro]
28/// Embed and optionally compress static assets for a web server
29pub fn embed_assets(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30    let parsed = parse_macro_input!(input as EmbedAssets);
31    quote! { #parsed }.into()
32}
33
34struct EmbedAssets {
35    assets_dir: AssetsDir,
36    validated_ignore_dirs: IgnoreDirs,
37    should_compress: ShouldCompress,
38}
39
40impl Parse for EmbedAssets {
41    fn parse(input: ParseStream) -> syn::Result<Self> {
42        let assets_dir: AssetsDir = input.parse()?;
43
44        // Default to no compression
45        let mut maybe_should_compress = None;
46        let mut maybe_ignore_dirs = None;
47
48        while !input.is_empty() {
49            input.parse::<Token![,]>()?;
50            let key: Ident = input.parse()?;
51            input.parse::<Token![=]>()?;
52
53            match key.to_string().as_str() {
54                "compress" => {
55                    let value = input.parse()?;
56                    maybe_should_compress = Some(value);
57                }
58                "ignore_dirs" => {
59                    let value = input.parse()?;
60                    maybe_ignore_dirs = Some(value);
61                }
62                _ => {
63                    return Err(syn::Error::new(
64                        key.span(),
65                        "Unknown key in embed_assets! macro. Expected `compress` or `ignore_dirs`",
66                    ));
67                }
68            }
69        }
70
71        let should_compress = maybe_should_compress.unwrap_or_else(|| {
72            ShouldCompress(LitBool {
73                value: false,
74                span: Span::call_site(),
75            })
76        });
77
78        let ignore_dirs_with_span = maybe_ignore_dirs.unwrap_or(IgnoreDirsWithSpan(vec![]));
79        let validated_ignore_dirs = validate_ignore_dirs(ignore_dirs_with_span, &assets_dir.0)?;
80
81        Ok(Self {
82            assets_dir,
83            validated_ignore_dirs,
84            should_compress,
85        })
86    }
87}
88
89impl ToTokens for EmbedAssets {
90    fn to_tokens(&self, tokens: &mut TokenStream) {
91        let AssetsDir(assets_dir) = &self.assets_dir;
92        let ignore_dirs = &self.validated_ignore_dirs;
93        let ShouldCompress(should_compress) = &self.should_compress;
94
95        let result = generate_static_routes(assets_dir, ignore_dirs, should_compress);
96
97        match result {
98            Ok(value) => {
99                tokens.extend(quote! {
100                    #value
101                });
102            }
103            Err(err_message) => {
104                let error = syn::Error::new(Span::call_site(), err_message);
105                tokens.extend(error.to_compile_error());
106            }
107        }
108    }
109}
110
111struct AssetsDir(LitStr);
112
113impl Parse for AssetsDir {
114    fn parse(input: ParseStream) -> syn::Result<Self> {
115        let input_span = input.span();
116        let assets_dir: LitStr = input.parse()?;
117        let literal = assets_dir.value();
118        let path = Path::new(&literal);
119        let metadata = match fs::metadata(path) {
120            Ok(meta) => meta,
121            Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
122                return Err(syn::Error::new(
123                    input_span,
124                    "The specified assets directory does not exist",
125                ));
126            }
127            Err(e) => {
128                return Err(syn::Error::new(
129                    input_span,
130                    format!(
131                        "Error reading directory {literal}: {}",
132                        DisplayFullError(&e)
133                    ),
134                ));
135            }
136        };
137
138        if !metadata.is_dir() {
139            return Err(syn::Error::new(
140                input_span,
141                "The specified assets directory is not a directory",
142            ));
143        }
144
145        Ok(AssetsDir(assets_dir))
146    }
147}
148
149struct IgnoreDirs(Vec<PathBuf>);
150
151struct IgnoreDirsWithSpan(Vec<(PathBuf, Span)>);
152
153impl Parse for IgnoreDirsWithSpan {
154    fn parse(input: ParseStream) -> syn::Result<Self> {
155        let inner_content;
156        bracketed!(inner_content in input);
157
158        let mut dirs = Vec::new();
159        while !inner_content.is_empty() {
160            let directory_span = inner_content.span();
161            let directory_str = inner_content.parse::<LitStr>()?;
162            let path = PathBuf::from(directory_str.value());
163            dirs.push((path, directory_span));
164
165            if !inner_content.is_empty() {
166                inner_content.parse::<Token![,]>()?;
167            }
168        }
169
170        Ok(IgnoreDirsWithSpan(dirs))
171    }
172}
173
174fn validate_ignore_dirs(
175    ignore_dirs: IgnoreDirsWithSpan,
176    assets_dir: &LitStr,
177) -> syn::Result<IgnoreDirs> {
178    let mut valid_ignore_dirs = Vec::new();
179    for (dir, span) in ignore_dirs.0 {
180        let full_path = PathBuf::from(assets_dir.value()).join(&dir);
181        match fs::metadata(&full_path) {
182            Ok(meta) if !meta.is_dir() => {
183                return Err(syn::Error::new(
184                    span,
185                    "The specified ignored directory is not a directory",
186                ));
187            }
188            Ok(_) => valid_ignore_dirs.push(full_path),
189            Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
190                return Err(syn::Error::new(
191                    span,
192                    "The specified ignored directory does not exist",
193                ))
194            }
195            Err(e) => {
196                return Err(syn::Error::new(
197                    span,
198                    format!(
199                        "Error reading ignored directory {}: {}",
200                        dir.to_string_lossy(),
201                        DisplayFullError(&e)
202                    ),
203                ))
204            }
205        };
206    }
207    Ok(IgnoreDirs(valid_ignore_dirs))
208}
209
210struct ShouldCompress(LitBool);
211
212impl Parse for ShouldCompress {
213    fn parse(input: ParseStream) -> syn::Result<Self> {
214        let lit = input.parse()?;
215        Ok(ShouldCompress(lit))
216    }
217}
218
219fn generate_static_routes(
220    assets_dir: &LitStr,
221    ignore_dirs: &IgnoreDirs,
222    should_compress: &LitBool,
223) -> Result<TokenStream, error::Error> {
224    let assets_dir_abs = Path::new(&assets_dir.value())
225        .canonicalize()
226        .map_err(Error::CannotCanonicalizeDirectory)?;
227    let assets_dir_abs_str = assets_dir_abs
228        .to_str()
229        .ok_or(Error::InvalidUnicodeInDirectoryName)?;
230    let canon_ignore_dirs = ignore_dirs
231        .0
232        .iter()
233        .map(|d| d.canonicalize().map_err(Error::CannotCanonicalizeIgnoreDir))
234        .collect::<Result<Vec<_>, _>>()?;
235
236    let mut routes = Vec::new();
237    for entry in glob(&format!("{assets_dir_abs_str}/**/*")).map_err(Error::Pattern)? {
238        let entry = entry.map_err(Error::Glob)?;
239        let metadata = entry.metadata().map_err(Error::CannotGetMetadata)?;
240        if metadata.is_dir() {
241            continue;
242        }
243
244        // Skip `entry`s which are located in ignored subdirectories
245        if canon_ignore_dirs
246            .iter()
247            .any(|ignore_dir| entry.starts_with(ignore_dir))
248        {
249            continue;
250        }
251
252        let contents = fs::read(&entry).map_err(Error::CannotReadEntryContents)?;
253
254        // Optionally compress files
255        let (maybe_gzip, maybe_zstd) = if should_compress.value {
256            let gzip = gzip_compress(&contents)?;
257            let zstd = zstd_compress(&contents)?;
258            (gzip, zstd)
259        } else {
260            (None, None)
261        };
262
263        // Create parameters for `::static_serve::static_route()`
264        let entry_path = entry
265            .to_str()
266            .ok_or(Error::InvalidUnicodeInEntryName)?
267            .strip_prefix(assets_dir_abs_str)
268            .unwrap_or_default();
269        let content_type = file_content_type(&entry)?;
270        let etag_str = etag(&contents);
271        let lit_byte_str_contents = LitByteStr::new(&contents, Span::call_site());
272        let maybe_gzip = option_to_token_stream_option(maybe_gzip.as_ref());
273        let maybe_zstd = option_to_token_stream_option(maybe_zstd.as_ref());
274
275        routes.push(quote! {
276            router = ::static_serve::static_route(
277                router,
278                #entry_path,
279                #content_type,
280                #etag_str,
281                #lit_byte_str_contents,
282                #maybe_gzip,
283                #maybe_zstd,
284            );
285        });
286    }
287
288    Ok(quote! {
289    pub fn static_router<S>() -> ::axum::Router<S>
290        where S: ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static {
291            let mut router = ::axum::Router::<S>::new();
292            #(#routes)*
293            router
294        }
295    })
296}
297
298fn gzip_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
299    let mut compressor = GzEncoder::new(Vec::new(), flate2::Compression::best());
300    compressor
301        .write_all(contents)
302        .map_err(|e| Error::Gzip(GzipType::CompressorWrite(e)))?;
303    let compressed = compressor
304        .finish()
305        .map_err(|e| Error::Gzip(GzipType::EncoderFinish(e)))?;
306
307    Ok(maybe_get_compressed(&compressed, contents))
308}
309
310fn zstd_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
311    let level = *zstd::compression_level_range().end();
312    let mut encoder = zstd::Encoder::new(Vec::new(), level).unwrap();
313    write_to_zstd_encoder(&mut encoder, contents)
314        .map_err(|e| Error::Zstd(ZstdType::EncoderWrite(e)))?;
315
316    let compressed = encoder
317        .finish()
318        .map_err(|e| Error::Zstd(ZstdType::EncoderFinish(e)))?;
319
320    Ok(maybe_get_compressed(&compressed, contents))
321}
322
323fn write_to_zstd_encoder(
324    encoder: &mut zstd::Encoder<'static, Vec<u8>>,
325    contents: &[u8],
326) -> io::Result<()> {
327    encoder.set_pledged_src_size(Some(
328        contents
329            .len()
330            .try_into()
331            .expect("contents size should fit into u64"),
332    ))?;
333    encoder.window_log(23)?;
334    encoder.include_checksum(false)?;
335    encoder.include_contentsize(false)?;
336    encoder.long_distance_matching(false)?;
337    encoder.write_all(contents)?;
338
339    Ok(())
340}
341
342fn option_to_token_stream_option<T: ToTokens>(opt: Option<&T>) -> TokenStream {
343    if let Some(inner) = opt {
344        quote! { ::std::option::Option::Some(#inner) }
345    } else {
346        quote! { ::std::option::Option::None }
347    }
348}
349
/// Returns `true` when `compressed_len` is strictly below 90% of `contents_len`.
///
/// The threshold is `contents_len / 10 * 9` with integer division, so it
/// rounds down and cannot overflow. Inputs shorter than 10 bytes are therefore
/// never considered worth compressing.
fn is_compression_significant(compressed_len: usize, contents_len: usize) -> bool {
    let threshold = (contents_len / 10) * 9;
    threshold > compressed_len
}
354
355fn maybe_get_compressed(compressed: &[u8], contents: &[u8]) -> Option<LitByteStr> {
356    is_compression_significant(compressed.len(), contents.len())
357        .then(|| LitByteStr::new(compressed, Span::call_site()))
358}
359
360fn file_content_type(path: &Path) -> Result<&'static str, error::Error> {
361    match path.extension() {
362        Some(ext) if ext.eq_ignore_ascii_case("css") => Ok("text/css"),
363        Some(ext) if ext.eq_ignore_ascii_case("js") => Ok("text/javascript"),
364        Some(ext) if ext.eq_ignore_ascii_case("txt") => Ok("text/plain"),
365        Some(ext) if ext.eq_ignore_ascii_case("woff") => Ok("font/woff"),
366        Some(ext) if ext.eq_ignore_ascii_case("woff2") => Ok("font/woff2"),
367        Some(ext) if ext.eq_ignore_ascii_case("svg") => Ok("image/svg+xml"),
368        ext => Err(error::Error::UnknownFileExtension(ext.map(Into::into))),
369    }
370}
371
372fn etag(contents: &[u8]) -> String {
373    let sha256 = Sha1::digest(contents);
374    let hash = u64::from_le_bytes(sha256[..8].try_into().unwrap())
375        ^ u64::from_le_bytes(sha256[8..16].try_into().unwrap());
376    format!("\"{hash:016x}\"")
377}