rz_embed/
lib.rs

1extern crate proc_macro;
2use std::path::PathBuf;
3
4use lazy_static::lazy_static;
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input, Expr, Ident, LitStr, Token,
11};
12
13use flate2::write::GzEncoder;
14use flate2::Compression;
15use regex::Regex;
16use std::fs::{self, create_dir_all, File};
17use std::io::{self, BufReader, BufWriter, Read};
18use std::path::Path;
19use walkdir::WalkDir;
20
21lazy_static! {
22    static ref RE_NON_ALPHANUMERIC: Regex = Regex::new(r"[^\w\s-]").unwrap();
23    static ref RE_WHITESPACE_HYPHEN: Regex = Regex::new(r"[-\s]+").unwrap();
24    static ref RE_REDUCE_UNDESCORES: Regex = Regex::new(r"_+").unwrap();
25}
26
27fn slugify(value: &str) -> String {
28    // Replace non-alphanumeric characters with underscores
29    let mut value = RE_NON_ALPHANUMERIC.replace_all(value, "_").to_string();
30    // Trim and convert to lowercase
31    value = value.trim().to_lowercase();
32    // Replace whitespace and hyphens with underscores
33    value = RE_WHITESPACE_HYPHEN.replace_all(&value, "_").to_string();
34    // Reduce multiple underscores to one
35    value = RE_REDUCE_UNDESCORES.replace_all(&value, "_").to_string();
36    // Remove leading underscore if present
37    if value.starts_with('_') {
38        value = value[1..].to_string();
39    }
40
41    value
42}
43
/// Content type of a resource that is served as raw bytes (see [`FileType::Binary`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContentType {
    /// No dedicated mapping; the rocket generator serves it as `ContentType::Binary`.
    Unknown,
    Png,
    Ttf,
    Ico,
}

impl ContentType {
    /// Map a file extension (without the leading dot) to a [`ContentType`].
    ///
    /// Matching is ASCII case-insensitive, so `logo.PNG` is treated like `logo.png`.
    pub fn from_extension(ext: &str) -> Self {
        match ext.to_ascii_lowercase().as_str() {
            "png" => Self::Png,
            "ttf" => Self::Ttf,
            "ico" => Self::Ico,
            _ => Self::Unknown,
        }
    }
}

/// Category of an embedded file; selects the responder generated for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FileType {
    Html,
    JavaScript,
    Css,
    Json,
    Xml,
    Plain,
    /// Anything that is not one of the text formats above.
    Binary(ContentType),
}

impl FileType {
    /// Classify a file by its extension (without the leading dot).
    ///
    /// Matching is ASCII case-insensitive. `None` (no extension) and unrecognised
    /// extensions map to [`FileType::Binary`].
    pub fn from_extension(ext: &Option<String>) -> Self {
        let ext = match ext {
            Some(ext) => ext.to_ascii_lowercase(),
            None => return FileType::Binary(ContentType::Unknown),
        };
        match ext.as_str() {
            "html" | "htm" => FileType::Html,
            "js" | "mjs" => FileType::JavaScript,
            "css" => FileType::Css,
            "json" => FileType::Json,
            "xml" => FileType::Xml,
            "txt" | "md" => FileType::Plain,
            other => FileType::Binary(ContentType::from_extension(other)),
        }
    }
}
90
91struct ResourceFile {
92    pub path: PathBuf,
93    pub slug: String,
94    pub const_name: String,
95    pub file_type: FileType,
96}
97
98impl std::fmt::Display for ResourceFile {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(f, "{:?} -- {:?}", self.path, self.file_type)
101    }
102}
103
104impl ResourceFile {
105    /// Create a ResourceFile from a path relative to the source directory
106    pub fn from_path(rel_path: &Path) -> Self {
107        let name = rel_path
108            .file_name()
109            .expect("Failed to get file name")
110            .to_str()
111            .expect("Failed to convert file name to str");
112        let path_str = rel_path.to_str().expect("Failed to convert path to str");
113        let extension = name.split('.').last().map(|ext| ext.to_string());
114        let file_type = FileType::from_extension(&extension);
115        let slug = slugify(path_str);
116        let const_name = slug.to_ascii_uppercase();
117        Self {
118            path: rel_path.to_owned(),
119            slug,
120            const_name,
121            file_type,
122        }
123    }
124
125    pub fn collect(root_dir: &PathBuf) -> Vec<ResourceFile> {
126        let mut result = Vec::new();
127        for entry in WalkDir::new(root_dir).into_iter().filter_map(Result::ok) {
128            let path = entry.path();
129            if path.is_file() {
130                let relative_path = path.strip_prefix(root_dir).expect("path error");
131                let resource = ResourceFile::from_path(relative_path);
132                result.push(resource);
133            }
134        }
135        result
136    }
137    //
138}
139
140fn compress_resource(src: &Path, dst: &Path, r: &ResourceFile) -> (u64, u64) {
141    let src = src.join(&r.path);
142    let meta_original = fs::metadata(&src).expect("Failed to get file metadata");
143    let compressed_path = dst.join(format!("{}.gz", r.slug));
144
145    // Check if the compressed file exists, return early if the original was not
146    // modified since compression occured.
147    if let Ok(compressed_metadata) = fs::metadata(&compressed_path) {
148        if meta_original
149            .modified()
150            .expect("Failed to get modified time")
151            <= compressed_metadata
152                .modified()
153                .expect("Failed to get modified time")
154        {
155            let original_sz = meta_original.len();
156            let compressed_sz = compressed_metadata.len();
157            println!(
158                "[~] {} already compressed {} -> {} bytes ({:.2}%)",
159                src.display(),
160                original_sz,
161                compressed_sz,
162                calculate_compression_rate(original_sz, compressed_sz)
163            );
164            return (meta_original.len(), compressed_metadata.len());
165        }
166    }
167
168    // Open the source file for reading
169    let f_in = File::open(&src).expect("Failed to open source file");
170    let reader = BufReader::new(f_in);
171
172    // Open the destination file for writing
173    let f_out = File::create(&compressed_path).expect("Failed to create compressed file");
174    let writer = BufWriter::new(f_out);
175
176    // Create a GzEncoder to compress the data
177    let mut encoder = GzEncoder::new(writer, Compression::default());
178
179    // Compress the file
180    io::copy(&mut reader.take(u64::MAX), &mut encoder).expect("Read failed");
181    encoder.finish().expect("Compression failed");
182
183    let meta_compressed = fs::metadata(compressed_path).expect("Failed to get metadata");
184    let original_sz = meta_original.len();
185    let compressed_sz = meta_compressed.len();
186    println!(
187        "[+] {}: {} -> {} bytes ({:.2}%)",
188        src.display(),
189        original_sz,
190        compressed_sz,
191        calculate_compression_rate(original_sz, compressed_sz)
192    );
193
194    (original_sz, compressed_sz)
195}
196
/// Percentage of bytes saved by compression, i.e. `100 * (1 - compressed / original)`.
///
/// An empty original yields `0.0` (no division by zero). The result is negative when the
/// "compressed" data is larger than the original.
fn calculate_compression_rate(original_size: u64, compressed_size: u64) -> f64 {
    match original_size {
        0 => 0.0,
        total => {
            let ratio = compressed_size as f64 / total as f64;
            (1.0 - ratio) * 100.0
        }
    }
}
204
205fn compress_resources(src: &Path, gz: &Path, resources: &Vec<ResourceFile>) -> (u64, u64) {
206    let mut total_original_sz = 0_u64;
207    let mut total_compressed_sz = 0_u64;
208    create_dir_all(gz).expect("Failed to create gz directory");
209    for r in resources {
210        let (orig, reduced) = compress_resource(src, gz, r);
211        total_original_sz += orig;
212        total_compressed_sz += reduced;
213    }
214    println!(
215        "[*] total: {} -> {} bytes ({:.2}%)",
216        total_original_sz,
217        total_compressed_sz,
218        calculate_compression_rate(total_original_sz, total_compressed_sz)
219    );
220    (total_original_sz, total_compressed_sz)
221}
222
223//
224
225struct InclAsCompressedArgs {
226    folder_path: LitStr,
227    module_name: syn::Ident,
228    rocket: bool,
229}
230
231impl Parse for InclAsCompressedArgs {
232    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
233        let folder_path: LitStr = input.parse()?;
234        input.parse::<Token![,]>()?;
235
236        let mut module_name = None;
237        let mut rocket = false;
238
239        while !input.is_empty() {
240            let ident: Ident = input.parse()?;
241            input.parse::<Token![=]>()?;
242
243            match ident.to_string().as_str() {
244                "module_name" => {
245                    let name: LitStr = input.parse()?;
246                    module_name = Some(Ident::new(&name.value(), name.span()));
247                }
248                "rocket" => {
249                    let pub_expr: Expr = input.parse()?;
250                    if let Expr::Lit(expr_lit) = pub_expr {
251                        if let syn::Lit::Bool(lit_bool) = expr_lit.lit {
252                            rocket = lit_bool.value;
253                        }
254                    }
255                }
256                _ => return Err(syn::Error::new(ident.span(), "Unexpected parameter")),
257            }
258
259            if input.peek(Token![,]) {
260                input.parse::<Token![,]>()?;
261            }
262        }
263
264        Ok(InclAsCompressedArgs {
265            folder_path: folder_path.clone(),
266            module_name: module_name.unwrap_or_else(|| Ident::new("embedded", folder_path.span())),
267            rocket,
268        })
269    }
270}
271
272fn generate_rocket_code(resources: &Vec<ResourceFile>) -> proc_macro2::TokenStream {
273    let mut generated_code = quote! {};
274    let mut route_idents = Vec::<syn::Ident>::new();
275
276    if resources
277        .iter()
278        .any(|r| matches!(r.file_type, FileType::Binary(_)))
279    {
280        generated_code = quote! {
281        #generated_code
282
283        pub struct BinaryResponse(&'static [u8], rocket::http::ContentType);
284        impl<'r> rocket::response::Responder<'r, 'static> for BinaryResponse {
285            fn respond_to(self, _: &'r rocket::request::Request<'_>) -> rocket::response::Result<'static> {
286                rocket::response::Response::build()
287                    .header(self.1)
288                    .sized_body(self.0.len(), std::io::Cursor::new(self.0))
289                    .ok()
290            }
291        }
292                    };
293    }
294
295    for res in resources {
296        let const_name = syn::Ident::new(&res.const_name, Span::call_site());
297        let mut handler_url = String::from("/");
298        handler_url.push_str(res.path.to_str().expect("path to_str failed"));
299        let handler_name = syn::Ident::new(
300            &format!("serve_{}", slugify(&handler_url)),
301            Span::call_site(),
302        );
303        route_idents.push(handler_name.clone());
304
305        let handler_code = match &res.file_type {
306            FileType::Html => quote! {
307                #[get(#handler_url)]
308                pub fn #handler_name() -> rocket::response::content::RawHtml<&'static [u8]> {
309                    rocket::response::content::RawHtml(&#const_name)
310                }
311            },
312            FileType::JavaScript => quote! {
313                #[get(#handler_url)]
314                pub fn #handler_name() -> rocket::response::content::RawJavaScript<&'static [u8]> {
315                    rocket::response::content::RawJavaScript(&#const_name)
316                }
317            },
318            FileType::Css => quote! {
319                #[get(#handler_url)]
320                pub fn #handler_name() -> rocket::response::content::RawCss<&'static [u8]> {
321                    rocket::response::content::RawCss(&#const_name)
322                }
323            },
324            FileType::Json => quote! {
325                #[get(#handler_url)]
326                pub fn #handler_name() -> rocket::response::content::RawJson<&'static [u8]> {
327                    rocket::response::content::RawJson(&#const_name)
328                }
329            },
330            FileType::Xml => quote! {
331                #[get(#handler_url)]
332                pub fn #handler_name() -> rocket::response::content::RawXml<&'static [u8]> {
333                    rocket::response::content::RawXml(&#const_name)
334                }
335            },
336            FileType::Plain => quote! {
337                #[get(#handler_url)]
338                pub fn #handler_name() -> rocket::response::content::RawText<&'static [u8]> {
339                    rocket::response::content::RawText(&#const_name)
340                }
341            },
342
343            FileType::Binary(content_type) => match content_type {
344                ContentType::Unknown => quote! {
345                    #[get(#handler_url)]
346                    pub fn #handler_name() -> BinaryResponse {
347                        BinaryResponse(&#const_name, rocket::http::ContentType::Binary)
348                    }
349                },
350                ContentType::Png => quote! {
351                    #[get(#handler_url)]
352                    pub fn #handler_name() -> BinaryResponse {
353                        BinaryResponse(&#const_name, rocket::http::ContentType::PNG)
354                    }
355                },
356                ContentType::Ttf => quote! {
357                    #[get(#handler_url)]
358                    pub fn #handler_name() -> BinaryResponse {
359                        BinaryResponse(&#const_name,  rocket::http::ContentType::TTF)
360                    }
361                },
362                ContentType::Ico => quote! {
363                    #[get(#handler_url)]
364                    pub fn #handler_name() -> BinaryResponse {
365                        BinaryResponse(&#const_name,   rocket::http::ContentType::Icon)
366                    }
367                },
368            },
369        };
370
371        generated_code = quote! {
372            #generated_code
373            #handler_code
374        };
375        // Add an additional route for Html files without the .html extension
376        if matches!(res.file_type, FileType::Html) {
377            let mut handler_url = PathBuf::from("/");
378            handler_url.push(res.path.to_str().expect("path to_str failed"));
379            handler_url.set_extension("");
380            let handler_url = handler_url.to_str().unwrap();
381            let handler_name = syn::Ident::new(
382                &format!("serve_{}", slugify(handler_url)),
383                Span::call_site(),
384            );
385            route_idents.push(handler_name.clone());
386            let handler = quote! {
387                #[get(#handler_url)]
388                pub fn #handler_name() -> rocket::response::content::RawHtml<&'static [u8]> {
389                    rocket::response::content::RawHtml(&#const_name)
390                }
391            };
392            generated_code = quote! {
393                #generated_code
394                #handler
395            };
396        }
397    }
398    let routes_collector = quote! {
399        pub fn routes() -> Vec<rocket::Route> {
400            routes![#(#route_idents),*]
401        }
402    };
403
404    generated_code = quote! {
405        #generated_code
406        #routes_collector
407    };
408    generated_code
409}
410
411#[proc_macro]
412pub fn include_as_compressed(input: TokenStream) -> TokenStream {
413    // Parse input
414    let args = parse_macro_input!(input as InclAsCompressedArgs);
415    let module_name = args.module_name;
416    // Input paths should behave like include_bytes
417    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
418    let folder_path = Path::new(&manifest_dir).join(args.folder_path.value());
419    let folder_str = folder_path
420        .to_str()
421        .expect("Failed to get str from folder_path");
422    let resources = ResourceFile::collect(&folder_path);
423    let input_slug = slugify(folder_str);
424    // GZ dir should conform to crate conventions
425    let gz_dir = PathBuf::from(&manifest_dir)
426        .join("target")
427        .join("rz-embed")
428        .join(input_slug);
429
430    compress_resources(&folder_path, &gz_dir, &resources);
431
432    let mut generated_code = quote! {
433        use lazy_static::lazy_static;
434        use flate2::read::GzDecoder;
435        use std::io::Read;
436    };
437
438    // Generate code for each resource
439    for res in &resources {
440        let name = syn::Ident::new(&res.const_name, Span::call_site());
441        let compressed_source_path = gz_dir.join(format!("{}.gz", res.slug));
442        let compressed_source_path_str = compressed_source_path.to_str().unwrap();
443        let res_code = quote! {
444            lazy_static! {
445                pub static ref #name: Vec<u8> = {
446                    let compressed_data: &[u8] = include_bytes!(#compressed_source_path_str);
447                    let mut decoder = GzDecoder::new(compressed_data);
448                    let mut decompressed_data = Vec::new();
449                    decoder.read_to_end(&mut decompressed_data).unwrap();
450                    decompressed_data
451                };
452            }
453        };
454        generated_code = quote! {
455            #generated_code
456            #res_code
457        };
458    }
459
460    // Function to restore to disk
461    let mut store_to_disk_fn_body = quote! {};
462    for res in &resources {
463        let name = syn::Ident::new(&res.const_name, Span::call_site());
464        let path = syn::LitStr::new(&res.path.to_string_lossy(), Span::call_site());
465        let parts = quote! {
466            {
467                let path = dst.join(#path);
468                let parent = path.parent().expect("Failed to get parent: {path:?}");
469                if !parent.is_dir() {
470                    std::fs::create_dir_all(parent)?;
471                }
472                let mut file_handle = std::fs::File::create(path)?;
473                std::io::Write::write_all(&mut file_handle, &#name)?;
474            }
475        };
476        store_to_disk_fn_body = quote! {
477            #store_to_disk_fn_body
478            #parts
479        };
480    }
481    let store_to_disk_fn = quote! {
482        pub fn extract_to_folder(dst: &std::path::Path) -> std::result::Result<(), std::io::Error> {
483            #store_to_disk_fn_body
484            Ok(())
485        }
486    };
487    generated_code = quote! {
488        #generated_code
489        #store_to_disk_fn
490    };
491
492    if args.rocket {
493        let rocket_code = generate_rocket_code(&resources);
494        generated_code = quote! {
495            #generated_code
496            #rocket_code
497        };
498    }
499
500    let result = quote! {
501        mod #module_name {
502            #generated_code
503        }
504    };
505
506    result.into()
507}
508
509
#[cfg(test)]
mod tests {
    use super::slugify;

    #[test]
    fn test_slugify() {
        // just a sanity check - maybe we should further limit this to prevent "uncommon code points"
        assert_eq!(
            slugify(
                "f0o/b$r/b🇺🇳z/!\"§$%&()=?`''¹²³¼½¬{[]}\\¸ÜÄäü*':;.,@ł€¶ŧ←↓→øþ¨~»«¢„“”µ·…txt"
            ),
            "f0o_b_r_b_z_üääü_ł_ŧ_øþ_µ_txt"
        );
        assert_eq!(slugify("a______b"), "a_b");
    }

    #[test]
    fn test_slugify_paths_and_separators() {
        // A leading separator (absolute paths) must not leave a leading underscore.
        assert_eq!(slugify("/abs/path"), "abs_path");
        assert_eq!(slugify("css/style.css"), "css_style_css");
        // Whitespace and hyphens collapse into single underscores; case is folded.
        assert_eq!(slugify("Hello World-Foo"), "hello_world_foo");
    }
}
522}