1use std::{
6 convert::Into,
7 fs,
8 io::{self, Write},
9 path::{Path, PathBuf},
10};
11
12use display_full_error::DisplayFullError;
13use flate2::write::GzEncoder;
14use glob::glob;
15use proc_macro2::{Span, TokenStream};
16use quote::{quote, ToTokens};
17use sha1::{Digest as _, Sha1};
18use syn::{
19 bracketed,
20 parse::{Parse, ParseStream},
21 parse_macro_input, Ident, LitBool, LitByteStr, LitStr, Token,
22};
23
24mod error;
25use error::{Error, GzipType, ZstdType};
26
27#[proc_macro]
28pub fn embed_assets(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 let parsed = parse_macro_input!(input as EmbedAssets);
31 quote! { #parsed }.into()
32}
33
/// Parsed and validated input of the `embed_assets!` macro.
struct EmbedAssets {
    /// Directory whose files are embedded; parsing checked it exists and is a directory.
    assets_dir: AssetsDir,
    /// Ignored directories, already joined onto the assets directory and checked to
    /// exist and be directories.
    validated_ignore_dirs: IgnoreDirs,
    /// Whether gzip and zstd variants of each file are generated (defaults to `false`).
    should_compress: ShouldCompress,
}
39
40impl Parse for EmbedAssets {
41 fn parse(input: ParseStream) -> syn::Result<Self> {
42 let assets_dir: AssetsDir = input.parse()?;
43
44 let mut maybe_should_compress = None;
46 let mut maybe_ignore_dirs = None;
47
48 while !input.is_empty() {
49 input.parse::<Token![,]>()?;
50 let key: Ident = input.parse()?;
51 input.parse::<Token![=]>()?;
52
53 match key.to_string().as_str() {
54 "compress" => {
55 let value = input.parse()?;
56 maybe_should_compress = Some(value);
57 }
58 "ignore_dirs" => {
59 let value = input.parse()?;
60 maybe_ignore_dirs = Some(value);
61 }
62 _ => {
63 return Err(syn::Error::new(
64 key.span(),
65 "Unknown key in embed_assets! macro. Expected `compress` or `ignore_dirs`",
66 ));
67 }
68 }
69 }
70
71 let should_compress = maybe_should_compress.unwrap_or_else(|| {
72 ShouldCompress(LitBool {
73 value: false,
74 span: Span::call_site(),
75 })
76 });
77
78 let ignore_dirs_with_span = maybe_ignore_dirs.unwrap_or(IgnoreDirsWithSpan(vec![]));
79 let validated_ignore_dirs = validate_ignore_dirs(ignore_dirs_with_span, &assets_dir.0)?;
80
81 Ok(Self {
82 assets_dir,
83 validated_ignore_dirs,
84 should_compress,
85 })
86 }
87}
88
89impl ToTokens for EmbedAssets {
90 fn to_tokens(&self, tokens: &mut TokenStream) {
91 let AssetsDir(assets_dir) = &self.assets_dir;
92 let ignore_dirs = &self.validated_ignore_dirs;
93 let ShouldCompress(should_compress) = &self.should_compress;
94
95 let result = generate_static_routes(assets_dir, ignore_dirs, should_compress);
96
97 match result {
98 Ok(value) => {
99 tokens.extend(quote! {
100 #value
101 });
102 }
103 Err(err_message) => {
104 let error = syn::Error::new(Span::call_site(), err_message);
105 tokens.extend(error.to_compile_error());
106 }
107 }
108 }
109}
110
111struct AssetsDir(LitStr);
112
113impl Parse for AssetsDir {
114 fn parse(input: ParseStream) -> syn::Result<Self> {
115 let input_span = input.span();
116 let assets_dir: LitStr = input.parse()?;
117 let literal = assets_dir.value();
118 let path = Path::new(&literal);
119 let metadata = match fs::metadata(path) {
120 Ok(meta) => meta,
121 Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
122 return Err(syn::Error::new(
123 input_span,
124 "The specified assets directory does not exist",
125 ));
126 }
127 Err(e) => {
128 return Err(syn::Error::new(
129 input_span,
130 format!(
131 "Error reading directory {literal}: {}",
132 DisplayFullError(&e)
133 ),
134 ));
135 }
136 };
137
138 if !metadata.is_dir() {
139 return Err(syn::Error::new(
140 input_span,
141 "The specified assets directory is not a directory",
142 ));
143 }
144
145 Ok(AssetsDir(assets_dir))
146 }
147}
148
/// Ignored directories that passed `validate_ignore_dirs`, each already joined onto
/// the assets directory path.
struct IgnoreDirs(Vec<PathBuf>);
150
151struct IgnoreDirsWithSpan(Vec<(PathBuf, Span)>);
152
153impl Parse for IgnoreDirsWithSpan {
154 fn parse(input: ParseStream) -> syn::Result<Self> {
155 let inner_content;
156 bracketed!(inner_content in input);
157
158 let mut dirs = Vec::new();
159 while !inner_content.is_empty() {
160 let directory_span = inner_content.span();
161 let directory_str = inner_content.parse::<LitStr>()?;
162 let path = PathBuf::from(directory_str.value());
163 dirs.push((path, directory_span));
164
165 if !inner_content.is_empty() {
166 inner_content.parse::<Token![,]>()?;
167 }
168 }
169
170 Ok(IgnoreDirsWithSpan(dirs))
171 }
172}
173
174fn validate_ignore_dirs(
175 ignore_dirs: IgnoreDirsWithSpan,
176 assets_dir: &LitStr,
177) -> syn::Result<IgnoreDirs> {
178 let mut valid_ignore_dirs = Vec::new();
179 for (dir, span) in ignore_dirs.0 {
180 let full_path = PathBuf::from(assets_dir.value()).join(&dir);
181 match fs::metadata(&full_path) {
182 Ok(meta) if !meta.is_dir() => {
183 return Err(syn::Error::new(
184 span,
185 "The specified ignored directory is not a directory",
186 ));
187 }
188 Ok(_) => valid_ignore_dirs.push(full_path),
189 Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
190 return Err(syn::Error::new(
191 span,
192 "The specified ignored directory does not exist",
193 ))
194 }
195 Err(e) => {
196 return Err(syn::Error::new(
197 span,
198 format!(
199 "Error reading ignored directory {}: {}",
200 dir.to_string_lossy(),
201 DisplayFullError(&e)
202 ),
203 ))
204 }
205 };
206 }
207 Ok(IgnoreDirs(valid_ignore_dirs))
208}
209
210struct ShouldCompress(LitBool);
211
212impl Parse for ShouldCompress {
213 fn parse(input: ParseStream) -> syn::Result<Self> {
214 let lit = input.parse()?;
215 Ok(ShouldCompress(lit))
216 }
217}
218
219fn generate_static_routes(
220 assets_dir: &LitStr,
221 ignore_dirs: &IgnoreDirs,
222 should_compress: &LitBool,
223) -> Result<TokenStream, error::Error> {
224 let assets_dir_abs = Path::new(&assets_dir.value())
225 .canonicalize()
226 .map_err(Error::CannotCanonicalizeDirectory)?;
227 let assets_dir_abs_str = assets_dir_abs
228 .to_str()
229 .ok_or(Error::InvalidUnicodeInDirectoryName)?;
230 let canon_ignore_dirs = ignore_dirs
231 .0
232 .iter()
233 .map(|d| d.canonicalize().map_err(Error::CannotCanonicalizeIgnoreDir))
234 .collect::<Result<Vec<_>, _>>()?;
235
236 let mut routes = Vec::new();
237 for entry in glob(&format!("{assets_dir_abs_str}/**/*")).map_err(Error::Pattern)? {
238 let entry = entry.map_err(Error::Glob)?;
239 let metadata = entry.metadata().map_err(Error::CannotGetMetadata)?;
240 if metadata.is_dir() {
241 continue;
242 }
243
244 if canon_ignore_dirs
246 .iter()
247 .any(|ignore_dir| entry.starts_with(ignore_dir))
248 {
249 continue;
250 }
251
252 let contents = fs::read(&entry).map_err(Error::CannotReadEntryContents)?;
253
254 let (maybe_gzip, maybe_zstd) = if should_compress.value {
256 let gzip = gzip_compress(&contents)?;
257 let zstd = zstd_compress(&contents)?;
258 (gzip, zstd)
259 } else {
260 (None, None)
261 };
262
263 let entry_path = entry
265 .to_str()
266 .ok_or(Error::InvalidUnicodeInEntryName)?
267 .strip_prefix(assets_dir_abs_str)
268 .unwrap_or_default();
269 let content_type = file_content_type(&entry)?;
270 let etag_str = etag(&contents);
271 let lit_byte_str_contents = LitByteStr::new(&contents, Span::call_site());
272 let maybe_gzip = option_to_token_stream_option(maybe_gzip.as_ref());
273 let maybe_zstd = option_to_token_stream_option(maybe_zstd.as_ref());
274
275 routes.push(quote! {
276 router = ::static_serve::static_route(
277 router,
278 #entry_path,
279 #content_type,
280 #etag_str,
281 #lit_byte_str_contents,
282 #maybe_gzip,
283 #maybe_zstd,
284 );
285 });
286 }
287
288 Ok(quote! {
289 pub fn static_router<S>() -> ::axum::Router<S>
290 where S: ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static {
291 let mut router = ::axum::Router::<S>::new();
292 #(#routes)*
293 router
294 }
295 })
296}
297
298fn gzip_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
299 let mut compressor = GzEncoder::new(Vec::new(), flate2::Compression::best());
300 compressor
301 .write_all(contents)
302 .map_err(|e| Error::Gzip(GzipType::CompressorWrite(e)))?;
303 let compressed = compressor
304 .finish()
305 .map_err(|e| Error::Gzip(GzipType::EncoderFinish(e)))?;
306
307 Ok(maybe_get_compressed(&compressed, contents))
308}
309
310fn zstd_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
311 let level = *zstd::compression_level_range().end();
312 let mut encoder = zstd::Encoder::new(Vec::new(), level).unwrap();
313 write_to_zstd_encoder(&mut encoder, contents)
314 .map_err(|e| Error::Zstd(ZstdType::EncoderWrite(e)))?;
315
316 let compressed = encoder
317 .finish()
318 .map_err(|e| Error::Zstd(ZstdType::EncoderFinish(e)))?;
319
320 Ok(maybe_get_compressed(&compressed, contents))
321}
322
323fn write_to_zstd_encoder(
324 encoder: &mut zstd::Encoder<'static, Vec<u8>>,
325 contents: &[u8],
326) -> io::Result<()> {
327 encoder.set_pledged_src_size(Some(
328 contents
329 .len()
330 .try_into()
331 .expect("contents size should fit into u64"),
332 ))?;
333 encoder.window_log(23)?;
334 encoder.include_checksum(false)?;
335 encoder.include_contentsize(false)?;
336 encoder.long_distance_matching(false)?;
337 encoder.write_all(contents)?;
338
339 Ok(())
340}
341
342fn option_to_token_stream_option<T: ToTokens>(opt: Option<&T>) -> TokenStream {
343 if let Some(inner) = opt {
344 quote! { ::std::option::Option::Some(#inner) }
345 } else {
346 quote! { ::std::option::Option::None }
347 }
348}
349
/// Returns `true` when `compressed_len` is strictly less than 90% of `contents_len`.
///
/// `maybe_get_compressed` uses this to decide whether a compressed variant is kept.
fn is_compression_significant(compressed_len: usize, contents_len: usize) -> bool {
    // Check `compressed / contents < 9 / 10` by cross-multiplying. Dividing first
    // (`contents_len / 10 * 9`) truncated the threshold to `floor(len / 10) * 9`,
    // far below 90% for small inputs (9 instead of 17.1 for 19 bytes). Widening
    // `usize` to `u128` is lossless and keeps the products from overflowing.
    (compressed_len as u128) * 10 < (contents_len as u128) * 9
}
354
355fn maybe_get_compressed(compressed: &[u8], contents: &[u8]) -> Option<LitByteStr> {
356 is_compression_significant(compressed.len(), contents.len())
357 .then(|| LitByteStr::new(compressed, Span::call_site()))
358}
359
360fn file_content_type(path: &Path) -> Result<&'static str, error::Error> {
361 match path.extension() {
362 Some(ext) if ext.eq_ignore_ascii_case("css") => Ok("text/css"),
363 Some(ext) if ext.eq_ignore_ascii_case("js") => Ok("text/javascript"),
364 Some(ext) if ext.eq_ignore_ascii_case("txt") => Ok("text/plain"),
365 Some(ext) if ext.eq_ignore_ascii_case("woff") => Ok("font/woff"),
366 Some(ext) if ext.eq_ignore_ascii_case("woff2") => Ok("font/woff2"),
367 Some(ext) if ext.eq_ignore_ascii_case("svg") => Ok("image/svg+xml"),
368 ext => Err(error::Error::UnknownFileExtension(ext.map(Into::into))),
369 }
370}
371
372fn etag(contents: &[u8]) -> String {
373 let sha256 = Sha1::digest(contents);
374 let hash = u64::from_le_bytes(sha256[..8].try_into().unwrap())
375 ^ u64::from_le_bytes(sha256[8..16].try_into().unwrap());
376 format!("\"{hash:016x}\"")
377}