Skip to main content

tower_embed_impl/
lib.rs

1use std::borrow::Cow;
2
3use camino::{Utf8Path as Path, Utf8PathBuf as PathBuf};
4use quote::ToTokens;
5use tower_embed_core::headers;
6
7/// Derive the `Embed` trait for unit struct, embedding assets from a folder.
8///
9/// ## Usage
10///
11/// Apply `#[derive(Embed)]` to a unit struct and specify the folder to embed using the
12/// `#[embed(folder = "...")]` attribute.
13///
14/// Optionally, specify the crate path with `#[embed(crate = path)]`. This is applicable when
15/// invoking re-exported derive from a public macro in a different crate.
16///
17/// The name of file to serve as index for directories can be customized using #[embed(index =
18/// "...")], the default is "index.html".
19#[proc_macro_derive(Embed, attributes(embed))]
20pub fn derive_embed(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let input = syn::parse_macro_input!(input as syn::DeriveInput);
22
23    expand_derive_embed(input)
24        .unwrap_or_else(|err| err.to_compile_error())
25        .into()
26}
27
28fn expand_derive_embed(input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
29    let DeriveEmbed { ident, attrs } = DeriveEmbed::from_ast(&input)?;
30    let DeriveEmbedAttrs {
31        folder,
32        crate_path,
33        index,
34    } = attrs;
35
36    let root = root_absolute_path(&folder);
37    let embedded_files = get_files(&root, &index).map(|file| {
38        let last_modified = tower_embed_core::last_modified(file.absolute_path.as_std_path())
39            .ok()
40            .and_then(|headers::LastModified(time)| {
41                time.duration_since(std::time::UNIX_EPOCH)
42                    .map(|duration| duration.as_secs())
43                    .ok()
44            });
45        let last_modified = match last_modified {
46            Some(secs) => quote::quote! { headers::LastModified::from_unix_timestamp(#secs) },
47            None => quote::quote! { None },
48        };
49
50        let relative_path = file.relative_path.as_str();
51        let absolute_path = file.absolute_path.as_str();
52        let redirect_path = format!("{relative_path}/{index}");
53        let redirect_path = redirect_path.trim_start_matches('/');
54
55        match file.kind {
56            FileKind::File => quote::quote! {{
57                let content = include_bytes!(#absolute_path).as_slice();
58                let metadata = Metadata {
59                    content_type: #crate_path::core::content_type(Path::new(#relative_path)),
60                    etag: Some(#crate_path::core::etag(content)),
61                    last_modified: #last_modified,
62                };
63                [(#relative_path, Entry::File(content, metadata))]
64            }},
65            FileKind::Dir => quote::quote! {{
66                [
67                    (#relative_path, Entry::Redirect(#redirect_path)),
68                    (concat!(#relative_path, "/"), Entry::Redirect(#redirect_path)),
69                ]
70            }},
71        }
72    });
73
74    let root = root.as_str();
75
76    let expanded = quote::quote! {
77        impl #crate_path::Embed for #ident {
78            #[cfg(not(debug_assertions))]
79            fn get(path: &str) -> impl Future<Output = std::io::Result<#crate_path::core::Embedded>> + Send + 'static {
80                use std::{collections::HashMap, sync::LazyLock, path::Path};
81
82                use #crate_path::core::{Content, Embedded, Metadata, headers};
83
84                enum Entry {
85                    File(&'static [u8], Metadata),
86                    Redirect(&'static str),
87                }
88
89                const FILES: LazyLock<HashMap<&'static str, Entry>> = LazyLock::new(|| {
90                    let mut m = HashMap::new();
91                    #(m.extend(#embedded_files);)*
92                    m
93                });
94
95                let mut path = path;
96                let output = loop {
97                    match FILES.get(path) {
98                        Some(Entry::File(bytes, metadata)) => break Ok(Embedded {
99                            content: Content::from_static(bytes),
100                            metadata: metadata.clone(),
101                        }),
102                        Some(Entry::Redirect(redirect_path)) => {
103                            path = redirect_path;
104                        }
105                        None => break Err(std::io::ErrorKind::NotFound.into()),
106                    };
107                };
108                std::future::ready(output)
109            }
110
111            #[cfg(debug_assertions)]
112            fn get(path: &str) -> impl Future<Output = std::io::Result<#crate_path::core::Embedded>> + Send + 'static {
113                use std::path::Path;
114
115                use #crate_path::core::{Content, Embedded, Metadata};
116
117                const ROOT: &str = #root;
118
119                let mut filename = Path::new(ROOT).join(path);
120                let stripped_path = Path::new(ROOT).join(path.trim_end_matches('/'));
121                if stripped_path.is_dir() {
122                    filename = filename.join(#index);
123                }
124
125                let metadata = Metadata {
126                    content_type: #crate_path::core::content_type(&filename),
127                    etag: None,
128                    last_modified: None,
129                };
130
131                async move {
132                    #crate_path::file::File::open(&filename).await.map(|file| {
133                        Embedded {
134                            content: Content::from_stream(file),
135                            metadata,
136                        }
137                    })
138                }
139            }
140        }
141    };
142
143    Ok(expanded)
144}
145
146/// A source data annotated with `#[derive(Embed)]``
147struct DeriveEmbed {
148    /// The struct name
149    ident: syn::Ident,
150    /// Attributes of structure
151    attrs: DeriveEmbedAttrs,
152}
153
154/// Attributes for `Embed` derive macro.
155struct DeriveEmbedAttrs {
156    /// The folder to embed
157    folder: String,
158    /// The path to the crate `tower_embed`
159    crate_path: syn::Path,
160    /// The index file name
161    index: Cow<'static, str>,
162}
163
164impl DeriveEmbed {
165    fn from_ast(input: &syn::DeriveInput) -> syn::Result<Self> {
166        let syn::Data::Struct(data) = &input.data else {
167            return Err(syn::Error::new_spanned(
168                input,
169                "`Embed` can only be derived for unit structs",
170            ));
171        };
172
173        if !matches!(&data.fields, syn::Fields::Unit) {
174            return Err(syn::Error::new_spanned(
175                &data.fields,
176                "`Embed` can only be derived for unit structs",
177            ));
178        }
179
180        let ident = input.ident.clone();
181        let attrs = DeriveEmbedAttrs::from_ast(input)?;
182
183        Ok(Self { ident, attrs })
184    }
185}
186
187impl DeriveEmbedAttrs {
188    fn from_ast(input: &syn::DeriveInput) -> syn::Result<Self> {
189        let mut folder = None;
190        let mut crate_path = None;
191        let mut index = None;
192
193        for attr in &input.attrs {
194            if !attr.path().is_ident("embed") {
195                continue;
196            }
197
198            let list = attr.meta.require_list()?;
199            if list.tokens.is_empty() {
200                continue;
201            }
202
203            list.parse_nested_meta(|meta| {
204                if meta.path.is_ident("folder") {
205                    let value: syn::LitStr = meta.value()?.parse()?;
206                    folder = Some(value.value());
207                } else if meta.path.is_ident("crate") {
208                    let value: syn::Path = meta.value()?.parse()?;
209                    crate_path = Some(value);
210                } else if meta.path.is_ident("index") {
211                    let value: syn::LitStr = meta.value()?.parse()?;
212                    index = Some(Cow::Owned(value.value()));
213                } else {
214                    let name = meta.path.to_token_stream();
215                    return Err(syn::Error::new_spanned(
216                        meta.path,
217                        format_args!("unknown `embed` attribute for `{}`", name),
218                    ));
219                }
220                Ok(())
221            })?;
222        }
223
224        let Some(folder) = folder else {
225            return Err(syn::Error::new_spanned(
226                input,
227                "#[derive(Embed)] requires `folder` attribute",
228            ));
229        };
230
231        let crate_path = crate_path.unwrap_or_else(|| syn::parse_quote! { tower_embed });
232        let index = index.unwrap_or(Cow::Borrowed("index.html"));
233
234        Ok(Self {
235            folder,
236            crate_path,
237            index,
238        })
239    }
240}
241
/// Resolves `folder` against the directory named by `CARGO_MANIFEST_DIR`.
///
/// # Panics
///
/// Panics when the `CARGO_MANIFEST_DIR` environment variable cannot be read.
fn root_absolute_path(folder: &str) -> PathBuf {
    let mut root: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("missing CARGO_MANIFEST_DIR environment variable");

    root.push(folder);
    root
}
248
249fn get_files(root: &Path, index: &str) -> impl Iterator<Item = File> {
250    walkdir::WalkDir::new(root)
251        .follow_links(true)
252        .sort_by_file_name()
253        .into_iter()
254        .filter_map(Result::ok)
255        .filter_map(move |entry| {
256            let kind = if entry.file_type().is_file() {
257                FileKind::File
258            } else if entry.file_type().is_dir() {
259                if !entry.path().join(index).is_file() {
260                    return None;
261                }
262
263                FileKind::Dir
264            } else {
265                return None;
266            };
267
268            let absolute_path: &Path = entry.path().try_into().unwrap();
269            let absolute_path = absolute_path.to_path_buf();
270
271            let relative_path = absolute_path
272                .canonicalize_utf8()
273                .unwrap()
274                .strip_prefix(root)
275                .unwrap()
276                .to_path_buf();
277
278            Some(File {
279                kind,
280                relative_path,
281                absolute_path,
282            })
283        })
284}
285
286struct File {
287    kind: FileKind,
288    relative_path: PathBuf,
289    absolute_path: PathBuf,
290}
291
/// The kind of an entry found while walking the embedded folder.
enum FileKind {
    /// A regular file.
    File,
    /// A directory.
    Dir,
}