rust_embed_impl/
lib.rs

1#![recursion_limit = "1024"]
2#![forbid(unsafe_code)]
3#[macro_use]
4extern crate quote;
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use rust_embed_utils::PathMatcher;
10use std::{
11  collections::BTreeMap,
12  env,
13  io::ErrorKind,
14  iter::FromIterator,
15  path::{Path, PathBuf},
16};
17use syn::{parse_macro_input, Data, DeriveInput, Expr, ExprLit, Fields, Lit, Meta, MetaNameValue};
18
19fn embedded(
20  ident: &syn::Ident, relative_folder_path: Option<&str>, absolute_folder_path: String, prefix: Option<&str>, includes: &[String], excludes: &[String],
21  metadata_only: bool, crate_path: &syn::Path,
22) -> syn::Result<TokenStream2> {
23  extern crate rust_embed_utils;
24
25  let mut match_values = BTreeMap::new();
26  let mut list_values = Vec::<String>::new();
27
28  let includes: Vec<&str> = includes.iter().map(AsRef::as_ref).collect();
29  let excludes: Vec<&str> = excludes.iter().map(AsRef::as_ref).collect();
30  let matcher = PathMatcher::new(&includes, &excludes);
31  for rust_embed_utils::FileEntry { rel_path, full_canonical_path } in rust_embed_utils::get_files(absolute_folder_path.clone(), matcher) {
32    match_values.insert(
33      rel_path.clone(),
34      embed_file(relative_folder_path, ident, &rel_path, &full_canonical_path, metadata_only, crate_path)?,
35    );
36
37    list_values.push(if let Some(prefix) = prefix {
38      format!("{}{}", prefix, rel_path)
39    } else {
40      rel_path
41    });
42  }
43
44  let array_len = list_values.len();
45
46  // If debug-embed is on, unconditionally include the code below. Otherwise,
47  // make it conditional on cfg(not(debug_assertions)).
48  let not_debug_attr = if cfg!(feature = "debug-embed") {
49    quote! {}
50  } else {
51    quote! { #[cfg(not(debug_assertions))]}
52  };
53
54  let handle_prefix = if let Some(prefix) = prefix {
55    quote! {
56      let file_path = file_path.strip_prefix(#prefix)?;
57    }
58  } else {
59    TokenStream2::new()
60  };
61  let match_values = match_values.into_iter().map(|(path, bytes)| {
62    quote! {
63        (#path, #bytes),
64    }
65  });
66  let value_type = if cfg!(feature = "compression") {
67    quote! { fn() -> #crate_path::EmbeddedFile }
68  } else {
69    quote! { #crate_path::EmbeddedFile }
70  };
71  let get_value = if cfg!(feature = "compression") {
72    quote! {|idx| (ENTRIES[idx].1)()}
73  } else {
74    quote! {|idx| ENTRIES[idx].1.clone()}
75  };
76  Ok(quote! {
77      #not_debug_attr
78      impl #ident {
79          /// Get an embedded file and its metadata.
80          pub fn get(file_path: &str) -> ::std::option::Option<#crate_path::EmbeddedFile> {
81            #handle_prefix
82            let key = file_path.replace("\\", "/");
83            const ENTRIES: &'static [(&'static str, #value_type)] = &[
84                #(#match_values)*];
85            let position = ENTRIES.binary_search_by_key(&key.as_str(), |entry| entry.0);
86            position.ok().map(#get_value)
87
88          }
89
90          fn names() -> ::std::slice::Iter<'static, &'static str> {
91              const ITEMS: [&str; #array_len] = [#(#list_values),*];
92              ITEMS.iter()
93          }
94
95          /// Iterates over the file paths in the folder.
96          pub fn iter() -> impl ::std::iter::Iterator<Item = ::std::borrow::Cow<'static, str>> {
97              Self::names().map(|x| ::std::borrow::Cow::from(*x))
98          }
99      }
100
101      #not_debug_attr
102      impl #crate_path::RustEmbed for #ident {
103        fn get(file_path: &str) -> ::std::option::Option<#crate_path::EmbeddedFile> {
104          #ident::get(file_path)
105        }
106        fn iter() -> #crate_path::Filenames {
107          #crate_path::Filenames::Embedded(#ident::names())
108        }
109      }
110  })
111}
112
113fn dynamic(
114  ident: &syn::Ident, folder_path: String, prefix: Option<&str>, includes: &[String], excludes: &[String], metadata_only: bool, crate_path: &syn::Path,
115) -> TokenStream2 {
116  let (handle_prefix, map_iter) = if let ::std::option::Option::Some(prefix) = prefix {
117    (
118      quote! { let file_path = file_path.strip_prefix(#prefix)?; },
119      quote! { ::std::borrow::Cow::Owned(format!("{}{}", #prefix, e.rel_path)) },
120    )
121  } else {
122    (TokenStream2::new(), quote! { ::std::borrow::Cow::from(e.rel_path) })
123  };
124
125  let declare_includes = quote! {
126    const INCLUDES: &[&str] = &[#(#includes),*];
127  };
128
129  let declare_excludes = quote! {
130    const EXCLUDES: &[&str] = &[#(#excludes),*];
131  };
132
133  // In metadata_only mode, we still need to read file contents to generate the
134  // file hash, but then we drop the file data.
135  let strip_contents = metadata_only.then_some(quote! {
136      .map(|mut file| { file.data = ::std::default::Default::default(); file })
137  });
138
139  let non_canonical_folder_path = Path::new(&folder_path);
140  let canonical_folder_path = non_canonical_folder_path
141    .canonicalize()
142    .or_else(|err| match err {
143      err if err.kind() == ErrorKind::NotFound => Ok(non_canonical_folder_path.to_owned()),
144      err => Err(err),
145    })
146    .expect("folder path must resolve to an absolute path");
147  let canonical_folder_path = canonical_folder_path.to_str().expect("absolute folder path must be valid unicode");
148
149  quote! {
150      #[cfg(debug_assertions)]
151      impl #ident {
152
153
154        fn matcher() -> #crate_path::utils::PathMatcher {
155            #declare_includes
156            #declare_excludes
157            static PATH_MATCHER: ::std::sync::OnceLock<#crate_path::utils::PathMatcher> = ::std::sync::OnceLock::new();
158            PATH_MATCHER.get_or_init(|| #crate_path::utils::PathMatcher::new(INCLUDES, EXCLUDES)).clone()
159        }
160          /// Get an embedded file and its metadata.
161          pub fn get(file_path: &str) -> ::std::option::Option<#crate_path::EmbeddedFile> {
162              #handle_prefix
163
164              let rel_file_path = file_path.replace("\\", "/");
165              let file_path = ::std::path::Path::new(#folder_path).join(&rel_file_path);
166
167              // Make sure the path requested does not escape the folder path
168              let canonical_file_path = file_path.canonicalize().ok()?;
169              if !canonical_file_path.starts_with(#canonical_folder_path) {
170                  // Tried to request a path that is not in the embedded folder
171
172                  // TODO: Currently it allows "path_traversal_attack" for the symlink files
173                  // For it to be working properly we need to get absolute path first
174                  // and check that instead if it starts with `canonical_folder_path`
175                  // https://doc.rust-lang.org/std/path/fn.absolute.html (currently nightly)
176                  // Should be allowed only if it was a symlink
177                  let metadata = ::std::fs::symlink_metadata(&file_path).ok()?;
178                  if !metadata.is_symlink() {
179                    return ::std::option::Option::None;
180                  }
181              }
182              let path_matcher = Self::matcher();
183              if path_matcher.is_path_included(&rel_file_path) {
184                #crate_path::utils::read_file_from_fs(&canonical_file_path).ok() #strip_contents
185              } else {
186                ::std::option::Option::None
187              }
188          }
189
190          /// Iterates over the file paths in the folder.
191          pub fn iter() -> impl ::std::iter::Iterator<Item = ::std::borrow::Cow<'static, str>> {
192              use ::std::path::Path;
193
194
195              #crate_path::utils::get_files(::std::string::String::from(#folder_path), Self::matcher())
196                  .map(|e| #map_iter)
197          }
198      }
199
200      #[cfg(debug_assertions)]
201      impl #crate_path::RustEmbed for #ident {
202        fn get(file_path: &str) -> ::std::option::Option<#crate_path::EmbeddedFile> {
203          #ident::get(file_path)
204        }
205        fn iter() -> #crate_path::Filenames {
206          // the return type of iter() is unnamable, so we have to box it
207          #crate_path::Filenames::Dynamic(::std::boxed::Box::new(#ident::iter()))
208        }
209      }
210  }
211}
212
213fn generate_assets(
214  ident: &syn::Ident, relative_folder_path: Option<&str>, absolute_folder_path: String, prefix: Option<String>, includes: Vec<String>, excludes: Vec<String>,
215  metadata_only: bool, crate_path: &syn::Path,
216) -> syn::Result<TokenStream2> {
217  let embedded_impl = embedded(
218    ident,
219    relative_folder_path,
220    absolute_folder_path.clone(),
221    prefix.as_deref(),
222    &includes,
223    &excludes,
224    metadata_only,
225    crate_path,
226  );
227  if cfg!(feature = "debug-embed") {
228    return embedded_impl;
229  }
230  let embedded_impl = embedded_impl?;
231  let dynamic_impl = dynamic(ident, absolute_folder_path, prefix.as_deref(), &includes, &excludes, metadata_only, crate_path);
232
233  Ok(quote! {
234      #embedded_impl
235      #dynamic_impl
236  })
237}
238
239fn embed_file(
240  folder_path: Option<&str>, ident: &syn::Ident, rel_path: &str, full_canonical_path: &str, metadata_only: bool, crate_path: &syn::Path,
241) -> syn::Result<TokenStream2> {
242  let file = rust_embed_utils::read_file_from_fs(Path::new(full_canonical_path)).expect("File should be readable");
243  let hash = file.metadata.sha256_hash();
244  let (last_modified, created) = if cfg!(feature = "deterministic-timestamps") {
245    (quote! { ::std::option::Option::Some(0u64) }, quote! { ::std::option::Option::Some(0u64) })
246  } else {
247    let last_modified = match file.metadata.last_modified() {
248      Some(last_modified) => quote! { ::std::option::Option::Some(#last_modified) },
249      None => quote! { ::std::option::Option::None },
250    };
251    let created = match file.metadata.created() {
252      Some(created) => quote! { ::std::option::Option::Some(#created) },
253      None => quote! { ::std::option::Option::None },
254    };
255    (last_modified, created)
256  };
257  #[cfg(feature = "mime-guess")]
258  let mimetype_tokens = {
259    let mt = file.metadata.mimetype();
260    quote! { , #mt }
261  };
262  #[cfg(not(feature = "mime-guess"))]
263  let mimetype_tokens = TokenStream2::new();
264
265  let embedding_code = if metadata_only {
266    quote! {
267        const BYTES: &'static [u8] = &[];
268    }
269  } else if cfg!(feature = "compression") {
270    let folder_path = folder_path.ok_or(syn::Error::new(ident.span(), "`folder` must be provided under `compression` feature."))?;
271    // Print some debugging information
272    let full_relative_path = PathBuf::from_iter([folder_path, rel_path]);
273    let full_relative_path = full_relative_path.to_string_lossy();
274    quote! {
275      #crate_path::flate!(static BYTES: [u8] from #full_relative_path);
276    }
277  } else {
278    quote! {
279      const BYTES: &'static [u8] = include_bytes!(#full_canonical_path);
280    }
281  };
282  let closure_args = if cfg!(feature = "compression") {
283    quote! { || }
284  } else {
285    quote! {}
286  };
287  Ok(quote! {
288       #closure_args {
289        #embedding_code
290
291        #crate_path::EmbeddedFile {
292            data: ::std::borrow::Cow::Borrowed(&BYTES),
293            metadata: #crate_path::Metadata::__rust_embed_new([#(#hash),*], #last_modified, #created #mimetype_tokens)
294        }
295      }
296  })
297}
298
299/// Find all pairs of the `name = "value"` attribute from the derive input
300fn find_attribute_values(ast: &syn::DeriveInput, attr_name: &str) -> Vec<String> {
301  ast
302    .attrs
303    .iter()
304    .filter(|value| value.path().is_ident(attr_name))
305    .filter_map(|attr| match &attr.meta {
306      Meta::NameValue(MetaNameValue {
307        value: Expr::Lit(ExprLit { lit: Lit::Str(val), .. }),
308        ..
309      }) => Some(val.value()),
310      _ => None,
311    })
312    .collect()
313}
314
315fn find_bool_attribute(ast: &syn::DeriveInput, attr_name: &str) -> Option<bool> {
316  ast
317    .attrs
318    .iter()
319    .find(|value| value.path().is_ident(attr_name))
320    .and_then(|attr| match &attr.meta {
321      Meta::NameValue(MetaNameValue {
322        value: Expr::Lit(ExprLit { lit: Lit::Bool(val), .. }),
323        ..
324      }) => Some(val.value()),
325      _ => None,
326    })
327}
328
329fn impl_rust_embed(ast: &syn::DeriveInput) -> syn::Result<TokenStream2> {
330  match ast.data {
331    Data::Struct(ref data) => match data.fields {
332      Fields::Unit => {}
333      _ => return Err(syn::Error::new_spanned(ast, "RustEmbed can only be derived for unit structs")),
334    },
335    _ => return Err(syn::Error::new_spanned(ast, "RustEmbed can only be derived for unit structs")),
336  };
337
338  let crate_path: syn::Path = find_attribute_values(ast, "crate_path")
339    .last()
340    .map_or_else(|| syn::parse_str("rust_embed").unwrap(), |v| syn::parse_str(v).unwrap());
341
342  let mut folder_paths = find_attribute_values(ast, "folder");
343  if folder_paths.len() != 1 {
344    return Err(syn::Error::new_spanned(
345      ast,
346      "#[derive(RustEmbed)] must contain one attribute like this #[folder = \"examples/public/\"]",
347    ));
348  }
349  let folder_path = folder_paths.remove(0);
350
351  let prefix = find_attribute_values(ast, "prefix").into_iter().next();
352  let includes = find_attribute_values(ast, "include");
353  let excludes = find_attribute_values(ast, "exclude");
354  let metadata_only = find_bool_attribute(ast, "metadata_only").unwrap_or(false);
355  let allow_missing = find_bool_attribute(ast, "allow_missing").unwrap_or(false);
356
357  #[cfg(not(feature = "include-exclude"))]
358  if !includes.is_empty() || !excludes.is_empty() {
359    return Err(syn::Error::new_spanned(
360      ast,
361      "Please turn on the `include-exclude` feature to use the `include` and `exclude` attributes",
362    ));
363  }
364
365  #[cfg(feature = "interpolate-folder-path")]
366  let folder_path = shellexpand::full(&folder_path)
367    .map_err(|v| syn::Error::new_spanned(ast, v.to_string()))?
368    .to_string();
369
370  // Base relative paths on the Cargo.toml location
371  let (relative_path, absolute_folder_path) = if Path::new(&folder_path).is_relative() {
372    let absolute_path = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap())
373      .join(&folder_path)
374      .to_str()
375      .unwrap()
376      .to_owned();
377    (Some(folder_path.clone()), absolute_path)
378  } else {
379    if cfg!(feature = "compression") {
380      return Err(syn::Error::new_spanned(ast, "`folder` must be a relative path under `compression` feature."));
381    }
382    (None, folder_path)
383  };
384
385  if !Path::new(&absolute_folder_path).exists() && !allow_missing {
386    let mut message = format!(
387      "#[derive(RustEmbed)] folder '{}' does not exist. cwd: '{}'",
388      absolute_folder_path,
389      std::env::current_dir().unwrap().to_str().unwrap()
390    );
391
392    // Add a message about the interpolate-folder-path feature if the path may
393    // include a variable
394    if absolute_folder_path.contains('$') && cfg!(not(feature = "interpolate-folder-path")) {
395      message += "\nA variable has been detected. RustEmbed can expand variables \
396                  when the `interpolate-folder-path` feature is enabled.";
397    }
398
399    return Err(syn::Error::new_spanned(ast, message));
400  };
401
402  generate_assets(
403    &ast.ident,
404    relative_path.as_deref(),
405    absolute_folder_path,
406    prefix,
407    includes,
408    excludes,
409    metadata_only,
410    &crate_path,
411  )
412}
413
414#[proc_macro_derive(RustEmbed, attributes(folder, prefix, include, exclude, allow_missing, metadata_only, crate_path))]
415pub fn derive_input_object(input: TokenStream) -> TokenStream {
416  let ast = parse_macro_input!(input as DeriveInput);
417  match impl_rust_embed(&ast) {
418    Ok(ok) => ok.into(),
419    Err(e) => e.to_compile_error().into(),
420  }
421}