webserver_base/
cache_buster.rs

1use std::{
2    collections::BTreeMap,
3    fmt::{self, Display},
4    fs::{self, File},
5    io::Read,
6};
7use std::{collections::VecDeque, path::Path};
8use std::{fs::DirEntry, path::PathBuf};
9
10use axum::{
11    body::Body,
12    extract::Request,
13    http::{HeaderMap, HeaderValue},
14    middleware::Next,
15    response::Response,
16};
17use chrono::{DateTime, Duration, TimeDelta, Utc};
18use regex::Regex;
19use reqwest::{
20    StatusCode,
21    header::{
22        CACHE_CONTROL, ETAG, EXPIRES, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_RANGE,
23        IF_UNMODIFIED_SINCE, PRAGMA,
24    },
25};
26use tracing::{error, instrument, warn};
27
/// Maps static asset paths to content-hashed ("cache-busted") file paths.
#[derive(Clone, Debug)]
pub struct CacheBuster {
    /// Directory whose files are hashed; lookups in `get_file` must start with this prefix.
    asset_directory: String,

    /// Original file path -> hashed file path, kept in sorted key order.
    cache: BTreeMap<String, String>,
}
34
35impl CacheBuster {
36    #[must_use]
37    #[instrument(skip_all)]
38    pub fn new(asset_directory: &str) -> Self {
39        Self {
40            asset_directory: asset_directory.to_string(),
41            cache: BTreeMap::new(),
42        }
43    }
44
45    #[instrument(skip_all)]
46    pub fn gen_cache(&mut self) {
47        self.cache = gen_cache(Path::new(&self.asset_directory));
48    }
49
50    /// Takes a path from root domain to a static asset (as it would be called from a browser, so with a leading slash)
51    /// and returns the version of the filepath that contains a unique hash.
52    ///
53    /// e.g. "/static/image/favicon/favicon.ico" -> "/static/image/favicon/favicon.66189abc248d80832e458ee37e93c9e8.ico"
54    ///
55    /// # Panics
56    ///
57    /// Panics if the file is not found in the cache.
58    #[must_use]
59    #[instrument(skip_all)]
60    pub fn get_file(&self, original_asset_file_path: &str) -> String {
61        // return the original path if the path does not start with the asset directory
62        if !original_asset_file_path.starts_with(&self.asset_directory) {
63            warn!(
64                "CacheBuster: File path does not start with asset directory: '{original_asset_file_path:?}'. Returning original path: '{original_asset_file_path:?}'."
65            );
66            return original_asset_file_path.to_string();
67        }
68
69        self.cache
70            .get(original_asset_file_path)
71            .cloned()
72            .unwrap_or_else(|| {
73                error!(
74                    "CacheBuster: File not found in cache: '{original_asset_file_path:?}'. Returning original path."
75                );
76                original_asset_file_path.to_string()
77            })
78    }
79
80    #[must_use]
81    #[instrument(skip_all)]
82    pub fn get_cache(&self) -> BTreeMap<String, String> {
83        self.cache.clone()
84    }
85
86    /// # Panics
87    ///
88    /// Panics if the file cannot be created or written to.
89    #[instrument(skip_all)]
90    pub fn print_to_file(&self, output_dir: &str) {
91        let output_path: PathBuf = Path::new(output_dir).join("cache-buster.json");
92        let file: File = File::create(&output_path)
93            .unwrap_or_else(|_| panic!("Failed to create file: {}", output_path.display()));
94
95        serde_json::to_writer_pretty(file, &self.cache)
96            .unwrap_or_else(|_| panic!("Failed to write JSON to file: {}", output_path.display()));
97    }
98
99    /// Updates the sourceMappingURL comment in `.js` files to point to the hashed `.js.map` file.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the file cannot be read or parsed.
104    #[instrument(skip_all)]
105    pub fn update_source_map_references(&self) {
106        let source_map_regex: Regex = Regex::new(r"//# sourceMappingURL=(.+\.js\.map)")
107            .unwrap_or_else(|_| panic!("Failed to compile sourceMappingURL regex"));
108
109        for (original_path, hashed_path) in &self.cache {
110            // only process `.js` files
111            if !std::path::Path::new(original_path)
112                .extension()
113                .is_some_and(|ext| ext.eq_ignore_ascii_case("js"))
114            {
115                continue;
116            }
117
118            // check for corresponding `.map` file
119            let original_map_path: String = format!("{original_path}.map");
120            let hashed_map_path: &String = match self.cache.get(&original_map_path) {
121                Some(path) => path,
122                None => continue,
123            };
124
125            // read `.map` file content
126            let mut content: String = fs::read_to_string(hashed_path)
127                .unwrap_or_else(|_| panic!("Failed to read file: {hashed_path}"));
128
129            // get just the `.map` filename
130            let hashed_map_filename: &str = Path::new(hashed_map_path)
131                .file_name()
132                .and_then(|s| s.to_str())
133                .unwrap_or_else(|| panic!("Invalid hashed map path"));
134
135            // replace the `sourceMappingURL` comment
136            if source_map_regex.is_match(&content) {
137                content = source_map_regex
138                    .replace(
139                        &content,
140                        format!("//# sourceMappingURL={hashed_map_filename}"),
141                    )
142                    .into_owned();
143
144                // Write the updated content back to the file
145                fs::write(hashed_path, content)
146                    .unwrap_or_else(|_| panic!("Failed to write file: {hashed_path}"));
147            }
148        }
149    }
150
151    /// Middleware to set never-cache headers for all responses.
152    ///
153    /// # Errors
154    ///
155    /// Will return `Error` if the request cannot be processed.
156    #[instrument(skip_all)]
157    pub async fn never_cache_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
158        let mut response: Response<Body> = next.run(req).await;
159
160        // remove ETag-related headers from the request
161        remove_etag_headers(response.headers_mut());
162
163        // set never-cache headers
164        response.headers_mut().insert(
165            EXPIRES,
166            HeaderValue::from_static("Thu, 01 Jan 1970 00:00:00 GMT"),
167        );
168        response.headers_mut().insert(
169            CACHE_CONTROL,
170            HeaderValue::from_static("no-cache, no-store, must-revalidate, private, max-age=0"),
171        );
172        response
173            .headers_mut()
174            .insert(PRAGMA, HeaderValue::from_static("no-cache"));
175
176        Ok(response)
177    }
178
179    /// Middleware to set forever cache headers for all responses.
180    ///
181    /// # Errors
182    ///
183    /// Will return `Error` if the request cannot be processed.
184    ///
185    /// # Panics
186    ///
187    /// Panics if the request cannot be processed.
188    #[instrument(skip_all)]
189    pub async fn forever_cache_middleware(
190        req: Request,
191        next: Next,
192    ) -> Result<Response, StatusCode> {
193        warn!(
194            "CacheBuster: Forever-cacheing resource: '{}'",
195            req.uri().path()
196        );
197        let mut response: Response<Body> = next.run(req).await;
198
199        // remove ETag-related headers from the request
200        remove_etag_headers(response.headers_mut());
201
202        // set forever-cache headers (1 year)
203        let one_year: TimeDelta = Duration::days(365);
204        let expires: DateTime<Utc> = Utc::now() + one_year;
205        response.headers_mut().insert(
206            EXPIRES,
207            HeaderValue::from_str(&expires.to_rfc2822()).unwrap(),
208        );
209        response.headers_mut().insert(
210            CACHE_CONTROL,
211            HeaderValue::from_static("public, max-age=31536000, must-revalidate, immutable"),
212        );
213
214        Ok(response)
215    }
216}
217
218impl Display for CacheBuster {
219    #[instrument(skip_all)]
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        // sort alphabetically by key
222        let mut keys: Vec<&String> = self.cache.keys().collect();
223        keys.sort();
224
225        write!(
226            f,
227            "CacheBuster (asset directory: '{}'):",
228            self.asset_directory
229        )?;
230        for key in keys {
231            write!(f, "\n\t'{}' -> '{}'", key, self.cache.get(key).unwrap())?;
232        }
233        Ok(())
234    }
235}
236
237#[instrument(skip_all)]
238fn gen_cache(root: &Path) -> BTreeMap<String, String> {
239    let mut cache: BTreeMap<String, String> = BTreeMap::new();
240
241    let mut dirs_to_visit: VecDeque<PathBuf> = VecDeque::new();
242    dirs_to_visit.push_back(root.to_path_buf());
243    while let Some(dir_path) = dirs_to_visit.pop_front() {
244        for entry in fs::read_dir(&dir_path)
245            .unwrap_or_else(|_| panic!("Failed to read directory: {}", dir_path.display()))
246        {
247            let error_msg: String = format!(
248                "Failed to read directory entry: {} -> {entry:?}",
249                dir_path.display(),
250            );
251            let entry: DirEntry = entry.expect(&error_msg);
252            let path: PathBuf = entry.path();
253
254            if path.is_dir() {
255                dirs_to_visit.push_back(path);
256            } else {
257                let original_file_path: String = path.to_string_lossy().to_string();
258                let new_file_path: String = generate_cache_busted_path(&path, root)
259                    .to_string_lossy()
260                    .to_string();
261
262                // rename the files on disk
263                fs::rename(&original_file_path, &new_file_path).unwrap_or_else(|_| {
264                    panic!("Failed to rename file: {original_file_path} -> {new_file_path}")
265                });
266
267                cache.insert(original_file_path, new_file_path);
268            }
269        }
270    }
271
272    cache
273}
274
275#[instrument(skip_all)]
276fn generate_cache_busted_path(file_path: &Path, root: &Path) -> PathBuf {
277    // read the file contents
278    let mut file: File = File::open(file_path).unwrap_or_else(|_| {
279        panic!(
280            "Failed to open file: {} -> {}",
281            root.display(),
282            file_path.display()
283        )
284    });
285    let mut contents: Vec<u8> = Vec::new();
286    file.read_to_end(&mut contents).unwrap_or_else(|_| {
287        panic!(
288            "Failed to read file: {} -> {}",
289            root.display(),
290            file_path.display()
291        )
292    });
293
294    // generate MD5 hash
295    let hash: String = format!("{:x}", md5::compute(contents));
296
297    // get the relative path components
298    let relative_path: &Path = file_path.strip_prefix(root).unwrap_or(file_path);
299    let parent: &Path = relative_path.parent().unwrap_or_else(|| Path::new(""));
300    let file_name: &str = relative_path
301        .file_name()
302        .and_then(|s| s.to_str())
303        .unwrap_or("");
304
305    let new_filename: String = if file_name.contains('.') {
306        // if at least one extension, insert hash before first period
307        let (name, rest) = file_name.split_once('.').unwrap();
308        format!("{name}.{hash}.{rest}")
309    } else {
310        // if no extension, append hash at the end
311        format!("{file_name}.{hash}")
312    };
313
314    // Combine with parent path and root
315    root.join(parent).join(new_filename)
316}
317
318#[instrument(skip_all)]
319fn remove_etag_headers(headers: &mut HeaderMap) {
320    headers.remove(ETAG);
321    headers.remove(IF_MODIFIED_SINCE);
322    headers.remove(IF_MATCH);
323    headers.remove(IF_NONE_MATCH);
324    headers.remove(IF_RANGE);
325    headers.remove(IF_UNMODIFIED_SINCE);
326}