1use std::{
6 convert::Into,
7 fmt::Display,
8 fs,
9 io::{self, Write},
10 path::{Path, PathBuf},
11};
12
13use display_full_error::DisplayFullError;
14use flate2::write::GzEncoder;
15use glob::glob;
16use proc_macro2::{Span, TokenStream};
17use quote::{quote, ToTokens};
18use sha1::{Digest as _, Sha1};
19use syn::{
20 bracketed,
21 parse::{Parse, ParseStream},
22 parse_macro_input, Ident, LitBool, LitByteStr, LitStr, Token,
23};
24
25mod error;
26use error::{Error, GzipType, ZstdType};
27
28#[proc_macro]
29pub fn embed_assets(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 let parsed = parse_macro_input!(input as EmbedAssets);
32 quote! { #parsed }.into()
33}
34
35struct EmbedAssets {
36 assets_dir: AssetsDir,
37 validated_ignore_dirs: IgnoreDirs,
38 should_compress: ShouldCompress,
39}
40
41impl Parse for EmbedAssets {
42 fn parse(input: ParseStream) -> syn::Result<Self> {
43 let assets_dir: AssetsDir = input.parse()?;
44
45 let mut maybe_should_compress = None;
47 let mut maybe_ignore_dirs = None;
48
49 while !input.is_empty() {
50 input.parse::<Token![,]>()?;
51 let key: Ident = input.parse()?;
52 input.parse::<Token![=]>()?;
53
54 match key.to_string().as_str() {
55 "compress" => {
56 let value = input.parse()?;
57 maybe_should_compress = Some(value);
58 }
59 "ignore_dirs" => {
60 let value = input.parse()?;
61 maybe_ignore_dirs = Some(value);
62 }
63 _ => {
64 return Err(syn::Error::new(
65 key.span(),
66 "Unknown key in embed_assets! macro. Expected `compress` or `ignore_dirs`",
67 ));
68 }
69 }
70 }
71
72 let should_compress = maybe_should_compress.unwrap_or_else(|| {
73 ShouldCompress(LitBool {
74 value: false,
75 span: Span::call_site(),
76 })
77 });
78
79 let ignore_dirs_with_span = maybe_ignore_dirs.unwrap_or(IgnoreDirsWithSpan(vec![]));
80 let validated_ignore_dirs = validate_ignore_dirs(ignore_dirs_with_span, &assets_dir.0)?;
81
82 Ok(Self {
83 assets_dir,
84 validated_ignore_dirs,
85 should_compress,
86 })
87 }
88}
89
90impl ToTokens for EmbedAssets {
91 fn to_tokens(&self, tokens: &mut TokenStream) {
92 let AssetsDir(assets_dir) = &self.assets_dir;
93 let ignore_dirs = &self.validated_ignore_dirs;
94 let ShouldCompress(should_compress) = &self.should_compress;
95
96 let result = generate_static_routes(assets_dir, ignore_dirs, should_compress);
97
98 match result {
99 Ok(value) => {
100 tokens.extend(quote! {
101 #value
102 });
103 }
104 Err(err_message) => {
105 let error = syn::Error::new(Span::call_site(), err_message);
106 tokens.extend(error.to_compile_error());
107 }
108 }
109 }
110}
111
112struct AssetsDir(ValidAssetsDirTypes);
113
114impl Parse for AssetsDir {
115 fn parse(input: ParseStream) -> syn::Result<Self> {
116 let input_span = input.span();
117 let assets_dir: ValidAssetsDirTypes = input.parse()?;
118 let literal = assets_dir.to_string();
119 let path = Path::new(&literal);
120 let metadata = match fs::metadata(path) {
121 Ok(meta) => meta,
122 Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
123 return Err(syn::Error::new(
124 input_span,
125 "The specified assets directory does not exist",
126 ));
127 }
128 Err(e) => {
129 return Err(syn::Error::new(
130 input_span,
131 format!(
132 "Error reading directory {literal}: {}",
133 DisplayFullError(&e)
134 ),
135 ));
136 }
137 };
138
139 if !metadata.is_dir() {
140 return Err(syn::Error::new(
141 input_span,
142 "The specified assets directory is not a directory",
143 ));
144 }
145
146 Ok(AssetsDir(assets_dir))
147 }
148}
149
150enum ValidAssetsDirTypes {
151 LiteralStr(LitStr),
152 Ident(Ident),
153}
154
155impl Display for ValidAssetsDirTypes {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 match self {
158 Self::LiteralStr(inner) => write!(f, "{}", inner.value()),
159 Self::Ident(inner) => write!(f, "{inner}"),
160 }
161 }
162}
163
164impl Parse for ValidAssetsDirTypes {
165 fn parse(input: ParseStream) -> syn::Result<Self> {
166 if let Ok(inner) = input.parse::<LitStr>() {
167 Ok(ValidAssetsDirTypes::LiteralStr(inner))
168 } else {
169 let inner = input.parse::<Ident>().map_err(|_| {
170 syn::Error::new(
171 input.span(),
172 "Assets directory must be a literal string or valid identifier",
173 )
174 })?;
175 Ok(ValidAssetsDirTypes::Ident(inner))
176 }
177 }
178}
179
180struct IgnoreDirs(Vec<PathBuf>);
181
182struct IgnoreDirsWithSpan(Vec<(PathBuf, Span)>);
183
184impl Parse for IgnoreDirsWithSpan {
185 fn parse(input: ParseStream) -> syn::Result<Self> {
186 let inner_content;
187 bracketed!(inner_content in input);
188
189 let mut dirs = Vec::new();
190 while !inner_content.is_empty() {
191 let directory_span = inner_content.span();
192 let directory_str = inner_content.parse::<LitStr>()?;
193 if !inner_content.is_empty() {
194 inner_content.parse::<Token![,]>()?;
195 }
196 let path = PathBuf::from(directory_str.value());
197 dirs.push((path, directory_span));
198 }
199
200 Ok(IgnoreDirsWithSpan(dirs))
201 }
202}
203
204fn validate_ignore_dirs(
205 ignore_dirs: IgnoreDirsWithSpan,
206 assets_dir: &ValidAssetsDirTypes,
207) -> syn::Result<IgnoreDirs> {
208 let mut valid_ignore_dirs = Vec::new();
209 for (dir, span) in ignore_dirs.0 {
210 let full_path = PathBuf::from(assets_dir.to_string()).join(&dir);
211 match fs::metadata(&full_path) {
212 Ok(meta) if !meta.is_dir() => {
213 return Err(syn::Error::new(
214 span,
215 "The specified ignored directory is not a directory",
216 ));
217 }
218 Ok(_) => valid_ignore_dirs.push(full_path),
219 Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => {
220 return Err(syn::Error::new(
221 span,
222 "The specified ignored directory does not exist",
223 ))
224 }
225 Err(e) => {
226 return Err(syn::Error::new(
227 span,
228 format!(
229 "Error reading ignored directory {}: {}",
230 dir.to_string_lossy(),
231 DisplayFullError(&e)
232 ),
233 ))
234 }
235 };
236 }
237 Ok(IgnoreDirs(valid_ignore_dirs))
238}
239
240struct ShouldCompress(LitBool);
241
242impl Parse for ShouldCompress {
243 fn parse(input: ParseStream) -> syn::Result<Self> {
244 let lit = input.parse()?;
245 Ok(ShouldCompress(lit))
246 }
247}
248
249fn generate_static_routes(
250 assets_dir: &ValidAssetsDirTypes,
251 ignore_dirs: &IgnoreDirs,
252 should_compress: &LitBool,
253) -> Result<TokenStream, error::Error> {
254 let assets_dir_abs = Path::new(&assets_dir.to_string())
255 .canonicalize()
256 .map_err(Error::CannotCanonicalizeDirectory)?;
257 let assets_dir_abs_str = assets_dir_abs
258 .to_str()
259 .ok_or(Error::InvalidUnicodeInDirectoryName)?;
260 let canon_ignore_dirs = ignore_dirs
261 .0
262 .iter()
263 .map(|d| d.canonicalize().map_err(Error::CannotCanonicalizeIgnoreDir))
264 .collect::<Result<Vec<_>, _>>()?;
265
266 let mut routes = Vec::new();
267 for entry in glob(&format!("{assets_dir_abs_str}/**/*")).map_err(Error::Pattern)? {
268 let entry = entry.map_err(Error::Glob)?;
269 let metadata = entry.metadata().map_err(Error::CannotGetMetadata)?;
270 if metadata.is_dir() {
271 continue;
272 }
273
274 if canon_ignore_dirs
276 .iter()
277 .any(|ignore_dir| entry.starts_with(ignore_dir))
278 {
279 continue;
280 }
281
282 let contents = fs::read(&entry).map_err(Error::CannotReadEntryContents)?;
283
284 let (maybe_gzip, maybe_zstd) = if should_compress.value {
286 let gzip = gzip_compress(&contents)?;
287 let zstd = zstd_compress(&contents)?;
288 (gzip, zstd)
289 } else {
290 (None, None)
291 };
292
293 let entry_path = entry
295 .to_str()
296 .ok_or(Error::InvalidUnicodeInEntryName)?
297 .strip_prefix(assets_dir_abs_str)
298 .unwrap_or_default();
299 let content_type = file_content_type(&entry)?;
300 let etag_str = etag(&contents);
301 let lit_byte_str_contents = LitByteStr::new(&contents, Span::call_site());
302 let maybe_gzip = option_to_token_stream_option(maybe_gzip.as_ref());
303 let maybe_zstd = option_to_token_stream_option(maybe_zstd.as_ref());
304
305 routes.push(quote! {
306 router = ::static_serve::static_route(
307 router,
308 #entry_path,
309 #content_type,
310 #etag_str,
311 #lit_byte_str_contents,
312 #maybe_gzip,
313 #maybe_zstd,
314 );
315 });
316 }
317
318 Ok(quote! {
319 pub fn static_router<S>() -> ::axum::Router<S>
320 where S: ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static {
321 let mut router = ::axum::Router::<S>::new();
322 #(#routes)*
323 router
324 }
325 })
326}
327
328fn gzip_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
329 let mut compressor = GzEncoder::new(Vec::new(), flate2::Compression::best());
330 compressor
331 .write_all(contents)
332 .map_err(|e| Error::Gzip(GzipType::CompressorWrite(e)))?;
333 let compressed = compressor
334 .finish()
335 .map_err(|e| Error::Gzip(GzipType::EncoderFinish(e)))?;
336
337 Ok(maybe_get_compressed(&compressed, contents))
338}
339
340fn zstd_compress(contents: &[u8]) -> Result<Option<LitByteStr>, Error> {
341 let level = *zstd::compression_level_range().end();
342 let mut encoder = zstd::Encoder::new(Vec::new(), level).unwrap();
343 write_to_zstd_encoder(&mut encoder, contents)
344 .map_err(|e| Error::Zstd(ZstdType::EncoderWrite(e)))?;
345
346 let compressed = encoder
347 .finish()
348 .map_err(|e| Error::Zstd(ZstdType::EncoderFinish(e)))?;
349
350 Ok(maybe_get_compressed(&compressed, contents))
351}
352
353fn write_to_zstd_encoder(
354 encoder: &mut zstd::Encoder<'static, Vec<u8>>,
355 contents: &[u8],
356) -> io::Result<()> {
357 encoder.set_pledged_src_size(Some(
358 contents
359 .len()
360 .try_into()
361 .expect("contents size should fit into u64"),
362 ))?;
363 encoder.window_log(23)?;
364 encoder.include_checksum(false)?;
365 encoder.include_contentsize(false)?;
366 encoder.long_distance_matching(false)?;
367 encoder.write_all(contents)?;
368
369 Ok(())
370}
371
372fn option_to_token_stream_option<T: ToTokens>(opt: Option<&T>) -> TokenStream {
373 if let Some(inner) = opt {
374 quote! { ::std::option::Option::Some(#inner) }
375 } else {
376 quote! { ::std::option::Option::None }
377 }
378}
379
/// Returns whether `compressed_len` is strictly less than 90% of `contents_len`, i.e.
/// whether a compressed variant saves enough to be worth embedding.
///
/// The check is `compressed * 10 < contents * 9`, evaluated in `u128`. It is exact and cannot
/// overflow. The previous `contents / 10 * 9` truncated first, so for example 16 of 19 bytes
/// (84%) was wrongly rejected, and inputs under 10 bytes could never qualify.
fn is_compression_significant(compressed_len: usize, contents_len: usize) -> bool {
    // `usize -> u128` is lossless on every supported target; `From` is not implemented for it.
    (compressed_len as u128) * 10 < (contents_len as u128) * 9
}
384
385fn maybe_get_compressed(compressed: &[u8], contents: &[u8]) -> Option<LitByteStr> {
386 is_compression_significant(compressed.len(), contents.len())
387 .then(|| LitByteStr::new(compressed, Span::call_site()))
388}
389
390fn file_content_type(path: &Path) -> Result<&'static str, error::Error> {
391 match path.extension() {
392 Some(ext) if ext.eq_ignore_ascii_case("css") => Ok("text/css"),
393 Some(ext) if ext.eq_ignore_ascii_case("js") => Ok("text/javascript"),
394 Some(ext) if ext.eq_ignore_ascii_case("txt") => Ok("text/plain"),
395 Some(ext) if ext.eq_ignore_ascii_case("woff") => Ok("font/woff"),
396 Some(ext) if ext.eq_ignore_ascii_case("woff2") => Ok("font/woff2"),
397 Some(ext) if ext.eq_ignore_ascii_case("svg") => Ok("image/svg+xml"),
398 ext => Err(error::Error::UnknownFileExtension(ext.map(Into::into))),
399 }
400}
401
402fn etag(contents: &[u8]) -> String {
403 let sha256 = Sha1::digest(contents);
404 let hash = u64::from_le_bytes(sha256[..8].try_into().unwrap())
405 ^ u64::from_le_bytes(sha256[8..16].try_into().unwrap());
406 format!("\"{hash:016x}\"")
407}