static_serve_macro/
lib.rs

1//! Proc macro crate for compressing and embedding static assets
2//! in a web server
3
4use std::{
5    convert::Into,
6    fs,
7    io::{self, Write},
8    path::{Path, PathBuf},
9};
10
11use display_full_error::DisplayFullError;
12use flate2::write::GzEncoder;
13use glob::glob;
14use proc_macro2::{Span, TokenStream};
15use quote::{quote, ToTokens};
16use sha1::{Digest as _, Sha1};
17use syn::{
18    bracketed,
19    parse::{Parse, ParseStream},
20    parse_macro_input, Ident, LitBool, LitByteStr, LitStr, Token,
21};
22
23mod error;
24use error::{Error, GzipType, ZstdType};
25
26#[proc_macro]
27/// Embed and optionally compress static assets for a web server
28pub fn embed_assets(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    let parsed = parse_macro_input!(input as EmbedAssets);
30    quote! { #parsed }.into()
31}
32
33#[proc_macro]
34/// Embed and optionally compress a single static asset for a web server
35pub fn embed_asset(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36    let parsed = parse_macro_input!(input as EmbedAsset);
37    quote! { #parsed }.into()
38}
39
40struct EmbedAsset {
41    asset_file: AssetFile,
42    should_compress: ShouldCompress,
43}
44
45struct AssetFile(LitStr);
46
47impl Parse for EmbedAsset {
48    fn parse(input: ParseStream) -> syn::Result<Self> {
49        let asset_file: AssetFile = input.parse()?;
50
51        // Default to no compression
52        let mut maybe_should_compress = None;
53
54        while !input.is_empty() {
55            input.parse::<Token![,]>()?;
56            let key: Ident = input.parse()?;
57            input.parse::<Token![=]>()?;
58
59            if matches!(key.to_string().as_str(), "compress") {
60                let value = input.parse()?;
61                maybe_should_compress = Some(value);
62            } else {
63                return Err(syn::Error::new(
64                    key.span(),
65                    format!(
66                        "Unknown key in `embed_asset!` macro. Expected `compress` but got {key}"
67                    ),
68                ));
69            }
70        }
71
72        let should_compress = maybe_should_compress.unwrap_or_else(|| {
73            ShouldCompress(LitBool {
74                value: false,
75                span: Span::call_site(),
76            })
77        });
78
79        Ok(Self {
80            asset_file,
81            should_compress,
82        })
83    }
84}
85
86impl Parse for AssetFile {
87    fn parse(input: ParseStream) -> syn::Result<Self> {
88        let input_span = input.span();
89        let asset_file: LitStr = input.parse()?;
90        let literal = asset_file.value();
91        let path = Path::new(&literal);
92        let metadata = match fs::metadata(path) {
93            Ok(meta) => meta,
94            Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
95                return Err(syn::Error::new(
96                    input_span,
97                    format!("The specified asset file ({literal}) does not exist."),
98                ));
99            }
100            Err(e) => {
101                return Err(syn::Error::new(
102                    input_span,
103                    format!("Error reading file {literal}: {}", DisplayFullError(&e)),
104                ));
105            }
106        };
107
108        if metadata.is_dir() {
109            return Err(syn::Error::new(
110                input_span,
111                "The specified asset is a directory, not a file. Did you mean to call `embed_assets!` instead?",
112            ));
113        }
114
115        Ok(AssetFile(asset_file))
116    }
117}
118
119impl ToTokens for EmbedAsset {
120    fn to_tokens(&self, tokens: &mut TokenStream) {
121        let AssetFile(asset_file) = &self.asset_file;
122        let ShouldCompress(should_compress) = &self.should_compress;
123
124        let result = generate_static_handler(asset_file, should_compress);
125
126        match result {
127            Ok(value) => {
128                tokens.extend(quote! {
129                    #value
130                });
131            }
132            Err(err_message) => {
133                let error = syn::Error::new(Span::call_site(), err_message);
134                tokens.extend(error.to_compile_error());
135            }
136        }
137    }
138}
139
140struct EmbedAssets {
141    assets_dir: AssetsDir,
142    validated_ignore_dirs: IgnoreDirs,
143    should_compress: ShouldCompress,
144    should_strip_html_ext: ShouldStripHtmlExt,
145}
146
147impl Parse for EmbedAssets {
148    fn parse(input: ParseStream) -> syn::Result<Self> {
149        let assets_dir: AssetsDir = input.parse()?;
150
151        // Default to no compression
152        let mut maybe_should_compress = None;
153        let mut maybe_ignore_dirs = None;
154        let mut maybe_should_strip_html_ext = None;
155
156        while !input.is_empty() {
157            input.parse::<Token![,]>()?;
158            let key: Ident = input.parse()?;
159            input.parse::<Token![=]>()?;
160
161            match key.to_string().as_str() {
162                "compress" => {
163                    let value = input.parse()?;
164                    maybe_should_compress = Some(value);
165                }
166                "ignore_dirs" => {
167                    let value = input.parse()?;
168                    maybe_ignore_dirs = Some(value);
169                }
170                "strip_html_ext" => {
171                    let value = input.parse()?;
172                    maybe_should_strip_html_ext = Some(value);
173                }
174                _ => {
175                    return Err(syn::Error::new(
176                        key.span(),
177                        "Unknown key in embed_assets! macro. Expected `compress`, `ignore_dirs`, or `strip_html_ext`",
178                    ));
179                }
180            }
181        }
182
183        let should_compress = maybe_should_compress.unwrap_or_else(|| {
184            ShouldCompress(LitBool {
185                value: false,
186                span: Span::call_site(),
187            })
188        });
189
190        let should_strip_html_ext = maybe_should_strip_html_ext.unwrap_or_else(|| {
191            ShouldStripHtmlExt(LitBool {
192                value: false,
193                span: Span::call_site(),
194            })
195        });
196
197        let ignore_dirs_with_span = maybe_ignore_dirs.unwrap_or(IgnoreDirsWithSpan(vec![]));
198        let validated_ignore_dirs = validate_ignore_dirs(ignore_dirs_with_span, &assets_dir.0)?;
199
200        Ok(Self {
201            assets_dir,
202            validated_ignore_dirs,
203            should_compress,
204            should_strip_html_ext,
205        })
206    }
207}
208
209impl ToTokens for EmbedAssets {
210    fn to_tokens(&self, tokens: &mut TokenStream) {
211        let AssetsDir(assets_dir) = &self.assets_dir;
212        let ignore_dirs = &self.validated_ignore_dirs;
213        let ShouldCompress(should_compress) = &self.should_compress;
214        let ShouldStripHtmlExt(should_strip_html_ext) = &self.should_strip_html_ext;
215
216        let result = generate_static_routes(
217            assets_dir,
218            ignore_dirs,
219            should_compress,
220            should_strip_html_ext,
221        );
222
223        match result {
224            Ok(value) => {
225                tokens.extend(quote! {
226                    #value
227                });
228            }
229            Err(err_message) => {
230                let error = syn::Error::new(Span::call_site(), err_message);
231                tokens.extend(error.to_compile_error());
232            }
233        }
234    }
235}
236
237struct AssetsDir(LitStr);
238
239impl Parse for AssetsDir {
240    fn parse(input: ParseStream) -> syn::Result<Self> {
241        let input_span = input.span();
242        let assets_dir: LitStr = input.parse()?;
243        let literal = assets_dir.value();
244        let path = Path::new(&literal);
245        let metadata = match fs::metadata(path) {
246            Ok(meta) => meta,
247            Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
248                return Err(syn::Error::new(
249                    input_span,
250                    "The specified assets directory does not exist",
251                ));
252            }
253            Err(e) => {
254                return Err(syn::Error::new(
255                    input_span,
256                    format!(
257                        "Error reading directory {literal}: {}",
258                        DisplayFullError(&e)
259                    ),
260                ));
261            }
262        };
263
264        if !metadata.is_dir() {
265            return Err(syn::Error::new(
266                input_span,
267                "The specified assets directory is not a directory",
268            ));
269        }
270
271        Ok(AssetsDir(assets_dir))
272    }
273}
274
275struct IgnoreDirs(Vec<PathBuf>);
276
277struct IgnoreDirsWithSpan(Vec<(PathBuf, Span)>);
278
279impl Parse for IgnoreDirsWithSpan {
280    fn parse(input: ParseStream) -> syn::Result<Self> {
281        let inner_content;
282        bracketed!(inner_content in input);
283
284        let mut dirs = Vec::new();
285        while !inner_content.is_empty() {
286            let directory_span = inner_content.span();
287            let directory_str = inner_content.parse::<LitStr>()?;
288            let path = PathBuf::from(directory_str.value());
289            dirs.push((path, directory_span));
290
291            if !inner_content.is_empty() {
292                inner_content.parse::<Token![,]>()?;
293            }
294        }
295
296        Ok(IgnoreDirsWithSpan(dirs))
297    }
298}
299
300fn validate_ignore_dirs(
301    ignore_dirs: IgnoreDirsWithSpan,
302    assets_dir: &LitStr,
303) -> syn::Result<IgnoreDirs> {
304    let mut valid_ignore_dirs = Vec::new();
305    for (dir, span) in ignore_dirs.0 {
306        let full_path = PathBuf::from(assets_dir.value()).join(&dir);
307        match fs::metadata(&full_path) {
308            Ok(meta) if !meta.is_dir() => {
309                return Err(syn::Error::new(
310                    span,
311                    "The specified ignored directory is not a directory",
312                ));
313            }
314            Ok(_) => valid_ignore_dirs.push(full_path),
315            Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
316                return Err(syn::Error::new(
317                    span,
318                    "The specified ignored directory does not exist",
319                ))
320            }
321            Err(e) => {
322                return Err(syn::Error::new(
323                    span,
324                    format!(
325                        "Error reading ignored directory {}: {}",
326                        dir.to_string_lossy(),
327                        DisplayFullError(&e)
328                    ),
329                ))
330            }
331        }
332    }
333    Ok(IgnoreDirs(valid_ignore_dirs))
334}
335
336struct ShouldCompress(LitBool);
337
338impl Parse for ShouldCompress {
339    fn parse(input: ParseStream) -> syn::Result<Self> {
340        let lit = input.parse()?;
341        Ok(ShouldCompress(lit))
342    }
343}
344
345struct ShouldStripHtmlExt(LitBool);
346
347impl Parse for ShouldStripHtmlExt {
348    fn parse(input: ParseStream) -> syn::Result<Self> {
349        let lit = input.parse()?;
350        Ok(ShouldStripHtmlExt(lit))
351    }
352}
353
354fn generate_static_routes(
355    assets_dir: &LitStr,
356    ignore_dirs: &IgnoreDirs,
357    should_compress: &LitBool,
358    should_strip_html_ext: &LitBool,
359) -> Result<TokenStream, error::Error> {
360    let assets_dir_abs = Path::new(&assets_dir.value())
361        .canonicalize()
362        .map_err(Error::CannotCanonicalizeDirectory)?;
363    let assets_dir_abs_str = assets_dir_abs
364        .to_str()
365        .ok_or(Error::InvalidUnicodeInDirectoryName)?;
366    let canon_ignore_dirs = ignore_dirs
367        .0
368        .iter()
369        .map(|d| d.canonicalize().map_err(Error::CannotCanonicalizeIgnoreDir))
370        .collect::<Result<Vec<_>, _>>()?;
371
372    let mut routes = Vec::new();
373    for entry in glob(&format!("{assets_dir_abs_str}/**/*")).map_err(Error::Pattern)? {
374        let entry = entry.map_err(Error::Glob)?;
375        let metadata = entry.metadata().map_err(Error::CannotGetMetadata)?;
376        if metadata.is_dir() {
377            continue;
378        }
379
380        // Skip `entry`s which are located in ignored subdirectories
381        if canon_ignore_dirs
382            .iter()
383            .any(|ignore_dir| entry.starts_with(ignore_dir))
384        {
385            continue;
386        }
387
388        let EmbeddedFileInfo {
389            entry_path,
390            content_type,
391            etag_str,
392            lit_byte_str_contents,
393            maybe_gzip,
394            maybe_zstd,
395        } = EmbeddedFileInfo::from_path(
396            &entry,
397            Some(assets_dir_abs_str),
398            should_compress,
399            should_strip_html_ext,
400        )?;
401
402        routes.push(quote! {
403            router = ::static_serve::static_route(
404                router,
405                #entry_path,
406                #content_type,
407                #etag_str,
408                #lit_byte_str_contents,
409                #maybe_gzip,
410                #maybe_zstd,
411            );
412        });
413    }
414
415    Ok(quote! {
416    pub fn static_router<S>() -> ::axum::Router<S>
417        where S: ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static {
418            let mut router = ::axum::Router::<S>::new();
419            #(#routes)*
420            router
421        }
422    })
423}
424
425fn generate_static_handler(
426    asset_file: &LitStr,
427    should_compress: &LitBool,
428) -> Result<TokenStream, error::Error> {
429    let asset_file_abs = Path::new(&asset_file.value())
430        .canonicalize()
431        .map_err(Error::CannotCanonicalizeFile)?;
432
433    let EmbeddedFileInfo {
434        entry_path: _,
435        content_type,
436        etag_str,
437        lit_byte_str_contents,
438        maybe_gzip,
439        maybe_zstd,
440    } = EmbeddedFileInfo::from_path(
441        &asset_file_abs,
442        None,
443        should_compress,
444        &LitBool {
445            value: false,
446            span: Span::call_site(),
447        },
448    )?;
449
450    let route = quote! {
451        ::static_serve::static_method_router(
452            #content_type,
453            #etag_str,
454            #lit_byte_str_contents,
455            #maybe_gzip,
456            #maybe_zstd,
457        )
458    };
459
460    Ok(route)
461}
462
463struct OptionBytesSlice(Option<LitByteStr>);
464impl ToTokens for OptionBytesSlice {
465    fn to_tokens(&self, tokens: &mut TokenStream) {
466        tokens.extend(if let Some(inner) = &self.0.as_ref() {
467            quote! { ::std::option::Option::Some(#inner) }
468        } else {
469            quote! { ::std::option::Option::None }
470        });
471    }
472}
473
474struct EmbeddedFileInfo<'a> {
475    /// When creating a `Router`, we need the API path/route to the
476    /// target file. If creating a `Handler`, this is not needed since
477    /// the router is responsible for the file's path on the server.
478    entry_path: Option<&'a str>,
479    content_type: String,
480    etag_str: String,
481    lit_byte_str_contents: LitByteStr,
482    maybe_gzip: OptionBytesSlice,
483    maybe_zstd: OptionBytesSlice,
484}
485
486impl<'a> EmbeddedFileInfo<'a> {
487    fn from_path(
488        pathbuf: &'a PathBuf,
489        assets_dir_abs_str: Option<&str>,
490        should_compress: &LitBool,
491        should_strip_html_ext: &LitBool,
492    ) -> Result<Self, Error> {
493        let contents = fs::read(pathbuf).map_err(Error::CannotReadEntryContents)?;
494
495        // Optionally compress files
496        let (maybe_gzip, maybe_zstd) = if should_compress.value {
497            let gzip = gzip_compress(&contents)?;
498            let zstd = zstd_compress(&contents)?;
499            (gzip, zstd)
500        } else {
501            (None, None)
502        };
503
504        let content_type = file_content_type(pathbuf)?;
505
506        // entry_path is only needed for the router (embed_assets!)
507        let entry_path = if let Some(dir) = assets_dir_abs_str {
508            if should_strip_html_ext.value && content_type == "text/html" {
509                Some(
510                    strip_html_ext(pathbuf)?
511                        .strip_prefix(dir)
512                        .unwrap_or_default(),
513                )
514            } else {
515                pathbuf
516                    .to_str()
517                    .ok_or(Error::InvalidUnicodeInEntryName)?
518                    .strip_prefix(dir)
519            }
520        } else {
521            None
522        };
523
524        let etag_str = etag(&contents);
525        let lit_byte_str_contents = LitByteStr::new(&contents, Span::call_site());
526        let maybe_gzip = OptionBytesSlice(maybe_gzip);
527        let maybe_zstd = OptionBytesSlice(maybe_zstd);
528
529        Ok(Self {
530            entry_path,
531            content_type,
532            etag_str,
533            lit_byte_str_contents,
534            maybe_gzip,
535            maybe_zstd,
536        })
537    }
538}
539
540fn gzip_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
541    let mut compressor = GzEncoder::new(Vec::new(), flate2::Compression::best());
542    compressor
543        .write_all(contents)
544        .map_err(|e| Error::Gzip(GzipType::CompressorWrite(e)))?;
545    let compressed = compressor
546        .finish()
547        .map_err(|e| Error::Gzip(GzipType::EncoderFinish(e)))?;
548
549    Ok(maybe_get_compressed(&compressed, contents))
550}
551
552fn zstd_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
553    let level = *zstd::compression_level_range().end();
554    let mut encoder = zstd::Encoder::new(Vec::new(), level).unwrap();
555    write_to_zstd_encoder(&mut encoder, contents)
556        .map_err(|e| Error::Zstd(ZstdType::EncoderWrite(e)))?;
557
558    let compressed = encoder
559        .finish()
560        .map_err(|e| Error::Zstd(ZstdType::EncoderFinish(e)))?;
561
562    Ok(maybe_get_compressed(&compressed, contents))
563}
564
565fn write_to_zstd_encoder(
566    encoder: &mut zstd::Encoder<'static, Vec<u8>>,
567    contents: &[u8],
568) -> io::Result<()> {
569    encoder.set_pledged_src_size(Some(
570        contents
571            .len()
572            .try_into()
573            .expect("contents size should fit into u64"),
574    ))?;
575    encoder.window_log(23)?;
576    encoder.include_checksum(false)?;
577    encoder.include_contentsize(false)?;
578    encoder.long_distance_matching(false)?;
579    encoder.write_all(contents)?;
580
581    Ok(())
582}
583
/// Report whether `compressed_len` is strictly below 90% of `contents_len`.
///
/// The threshold is computed as `contents_len - contents_len / 10`, i.e. the
/// ceiling of 90% of `contents_len`. For integer lengths, comparing against
/// that ceiling is equivalent to comparing against the exact 90% value, and
/// the subtraction can neither overflow nor underflow. Dividing before
/// multiplying (`contents_len / 10 * 9`) would instead round the threshold
/// down by up to nine bytes and yield zero for inputs shorter than ten bytes.
fn is_compression_significant(compressed_len: usize, contents_len: usize) -> bool {
    let ninety_pct_original = contents_len - contents_len / 10;
    compressed_len < ninety_pct_original
}
588
589fn maybe_get_compressed(compressed: &[u8], contents: &[u8]) -> Option<LitByteStr> {
590    is_compression_significant(compressed.len(), contents.len())
591        .then(|| LitByteStr::new(compressed, Span::call_site()))
592}
593
594/// Use `mime_guess` to get the best guess of the file's MIME type
595/// by looking at its extension, or return an error if unable.
596///
597/// We accept the first guess because [`mime_guess` updates the order
598/// according to the latest IETF RTC](https://docs.rs/mime_guess/2.0.5/mime_guess/struct.MimeGuess.html#note-ordering)
599fn file_content_type(path: &Path) -> Result<String, error::Error> {
600    match path.extension() {
601        Some(ext) => {
602            let guesses = mime_guess::MimeGuess::from_ext(
603                ext.to_str()
604                    .ok_or(error::Error::InvalidFileExtension(path.into()))?,
605            );
606
607            if let Some(guess) = guesses.first_raw() {
608                Ok(guess.to_owned())
609            } else {
610                Err(error::Error::UnknownFileExtension(
611                    path.extension().map(Into::into),
612                ))
613            }
614        }
615        None => Err(error::Error::UnknownFileExtension(None)),
616    }
617}
618
619fn etag(contents: &[u8]) -> String {
620    let sha256 = Sha1::digest(contents);
621    let hash = u64::from_le_bytes(sha256[..8].try_into().unwrap())
622        ^ u64::from_le_bytes(sha256[8..16].try_into().unwrap());
623    format!("\"{hash:016x}\"")
624}
625
626fn strip_html_ext(entry: &Path) -> Result<&str, Error> {
627    let entry_str = entry.to_str().ok_or(Error::InvalidUnicodeInEntryName)?;
628    let mut output = entry_str;
629
630    // Strip the extension
631    if let Some(prefix) = output.strip_suffix(".html") {
632        output = prefix;
633    } else if let Some(prefix) = output.strip_suffix(".htm") {
634        output = prefix;
635    }
636
637    // If it was `/index.html` or `/index.htm`, also remove `index`
638    if output.ends_with("/index") {
639        output = output.strip_suffix("index").unwrap_or("/");
640    }
641
642    Ok(output)
643}