trillium_html_rewriter/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11#![doc = "../README.md"]
12
13use cfg_if::cfg_if;
14pub use lol_async::html;
15use lol_async::{html::Settings, rewrite};
16use mime::Mime;
17use std::{future::Future, str::FromStr};
18use trillium::{
19    async_trait, Body, Conn, Handler,
20    KnownHeaderName::{ContentLength, ContentType},
21};
22
23/**
24trillium handler for html rewriting
25*/
26pub struct HtmlRewriter {
27    settings: Box<dyn Fn() -> Settings<'static, 'static> + Send + Sync + 'static>,
28}
29
30impl std::fmt::Debug for HtmlRewriter {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("HtmlRewriter").finish()
33    }
34}
35
36fn spawn_local(fut: impl Future + 'static) {
37    cfg_if! {
38        if #[cfg(feature = "async-std")] {
39            async_std_crate::task::spawn_local(fut);
40        } else if #[cfg(feature = "smol")] {
41            async_global_executor::spawn_local(fut).detach();
42        } else if #[cfg(feature = "tokio")] {
43            tokio_crate::task::spawn_local(fut);
44        } else {
45            async_global_executor::spawn_local(fut).detach();
46        }
47    }
48}
49
50#[async_trait]
51impl Handler for HtmlRewriter {
52    async fn run(&self, mut conn: Conn) -> Conn {
53        let html = conn
54            .headers_mut()
55            .get_str(ContentType)
56            .and_then(|c| Mime::from_str(c).ok())
57            .map(|m| m.subtype() == "html")
58            .unwrap_or_default();
59
60        if html && conn.inner().response_body().is_some() {
61            let body = conn.inner_mut().take_response_body().unwrap();
62            let (fut, reader) = rewrite(body, (self.settings)());
63            spawn_local(fut);
64            conn.headers_mut().remove(ContentLength); // we no longer know the content length, if we ever did
65            conn.with_body(Body::new_streaming(reader, None))
66        } else {
67            conn
68        }
69    }
70}
71
72impl HtmlRewriter {
73    /**
74    construct a new html rewriter from the provided `fn() ->
75    Settings`. See [`lol_async::html::Settings`] for more information.
76     */
77    pub fn new(f: impl Fn() -> Settings<'static, 'static> + Send + Sync + 'static) -> Self {
78        Self {
79            settings: Box::new(f)
80                as Box<dyn Fn() -> Settings<'static, 'static> + Send + Sync + 'static>,
81        }
82    }
83}