simple_ssr_rs/
lib.rs

1pub use anyhow;
2pub use salvo;
3pub use salvo::catcher::Catcher;
4#[cfg(feature = "http3")]
5pub use salvo::conn::rustls::{Keycert, RustlsConfig};
6#[cfg(feature = "http3")]
7use salvo::conn::tcp::TcpAcceptor;
8
9use salvo::prelude::*;
10use salvo::serve_static::StaticDir;
11pub use serde_json::{self, Value};
12use std::{collections::HashMap, marker::PhantomData, sync::Arc};
13pub use tera::{self, Context, Filter, Function, Tera};
14pub use tokio::{self};
15
16type TeraFunctionMap = HashMap<String, Arc<dyn Function + 'static>>;
17type TeraFilterMap = HashMap<String, Arc<dyn Filter + 'static>>;
18type MetaInfoCollector =
19    Option<Arc<dyn Fn(&Request) -> HashMap<String, Value> + 'static + Send + Sync>>;
20
21type HookViewPathHandlerType = Option<Arc<dyn Fn(&mut Request, String) -> String + Send + Sync>>;
22struct CallableObjectForTera<F: ?Sized>(Arc<F>);
23
24impl<F: Function + ?Sized> Function for CallableObjectForTera<F> {
25    fn call(&self, args: &HashMap<String, Value>) -> tera::Result<Value> {
26        self.0.call(args)
27    }
28}
29
30impl<F: Filter + ?Sized> Filter for CallableObjectForTera<F> {
31    fn filter(&self, value: &Value, args: &HashMap<String, Value>) -> tera::Result<Value> {
32        self.0.filter(value, args)
33    }
34}
35
/// Locations of the TLS certificate and private key files used when serving
/// over HTTP/3.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Http3Certification {
    /// Path to the certificate file.
    pub cert: std::path::PathBuf,
    /// Path to the private key file.
    pub key: std::path::PathBuf,
}
40
41pub struct SSRender<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> = anyhow::Error> {
42    pub_assets_dir_name: String,
43    tmpl_dir_name: String,
44    host: String,
45    tmpl_func_map: TeraFunctionMap,
46    tmpl_filter_map: TeraFilterMap,
47    ctx_generator: MetaInfoCollector,
48    phantom_data_: PhantomData<ErrorWriter>,
49    default_view_file_postfix: String,
50    default_view_file_name: String,
51    listing_assets: bool,
52    default_asset_filename: Option<String>,
53    #[cfg(feature = "http3")]
54    use_http3: Option<Http3Certification>,
55    hook_view_path: HookViewPathHandlerType,
56}
57impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> + Send + Sync + 'static>
58    SSRender<ErrorWriter>
59{
60    pub fn new(host: &str) -> Self {
61        Self {
62            pub_assets_dir_name: "public".to_owned(),
63            tmpl_dir_name: "templates".to_owned(),
64            host: host.to_owned(),
65            tmpl_func_map: HashMap::new(),
66            tmpl_filter_map: HashMap::new(),
67            ctx_generator: None,
68            phantom_data_: PhantomData,
69            default_view_file_postfix: "html".to_owned(),
70            default_view_file_name: "index.html".to_owned(),
71            listing_assets: true,
72            default_asset_filename: None,
73            #[cfg(feature = "http3")]
74            use_http3: None,
75            hook_view_path: None,
76        }
77    }
78
79    pub fn host(&self) -> &str {
80        &self.host
81    }
82
83    pub fn set_pub_dir_name(&mut self, path: &str) {
84        self.pub_assets_dir_name = path.to_owned();
85    }
86
87    pub fn set_tmpl_dir_name(&mut self, path: &str) {
88        self.tmpl_dir_name = path.to_owned();
89    }
90
91    pub fn register_function<F: Function + 'static>(&mut self, k: String, f: F) {
92        self.tmpl_func_map.insert(k, Arc::new(f));
93    }
94
95    pub fn rm_registed_function(&mut self, k: String) {
96        self.tmpl_func_map.remove(&k);
97    }
98
99    pub fn registed_functions(&self) -> &TeraFunctionMap {
100        &self.tmpl_func_map
101    }
102
103    pub fn register_filter<F: Filter + 'static>(&mut self, k: String, f: F) {
104        self.tmpl_filter_map.insert(k, Arc::new(f));
105    }
106
107    pub fn rm_registed_filter(&mut self, k: String) {
108        self.tmpl_filter_map.remove(&k);
109    }
110
111    pub fn registed_filters(&self) -> &TeraFilterMap {
112        &self.tmpl_filter_map
113    }
114
115    pub fn pub_dir_name(&self) -> &str {
116        &self.pub_assets_dir_name
117    }
118
119    pub fn tmpl_dir_name(&self) -> &str {
120        &self.tmpl_dir_name
121    }
122
123    pub fn set_ctx_generator(
124        &mut self,
125        f: impl Fn(&Request) -> HashMap<String, Value> + 'static + Send + Sync,
126    ) {
127        self.ctx_generator = Some(Arc::new(f));
128    }
129
130    pub fn rm_ctx_generator(&mut self) {
131        self.ctx_generator = None;
132    }
133
134    pub fn gen_tera_builder(&self) -> TeraBuilder {
135        TeraBuilder::new(
136            format!("{}/**/*", self.tmpl_dir_name),
137            self.tmpl_func_map.clone(),
138            self.tmpl_filter_map.clone(),
139            self.ctx_generator.clone(),
140        )
141    }
142
143    pub fn set_default_file_postfix(&mut self, postfix: &str) {
144        self.default_view_file_postfix = postfix.to_owned();
145    }
146
147    pub fn default_file_postfix(&self) -> &str {
148        &self.default_view_file_postfix
149    }
150
151    pub fn set_listing_assets(&mut self, v: bool) {
152        self.listing_assets = v;
153    }
154
155    pub fn listing_assets(&self) -> bool {
156        self.listing_assets
157    }
158
159    pub fn set_default_assets_filename(&mut self, v: &str) {
160        self.default_asset_filename = Some(v.to_owned());
161    }
162
163    pub fn default_assets_filename(&self) -> &Option<String> {
164        &self.default_asset_filename
165    }
166    #[cfg(feature = "http3")]
167    pub fn set_use_http3(&mut self, cert: Http3Certification) {
168        self.use_http3 = Some(cert);
169    }
170    #[cfg(feature = "http3")]
171    pub fn use_http3(&self) -> Option<&Http3Certification> {
172        self.use_http3.as_ref()
173    }
174
175    pub fn set_hook_view_path<F: Fn(&mut Request, String) -> String + 'static + Send + Sync>(
176        &mut self,
177        hook: Option<F>,
178    ) {
179        match hook {
180            Some(f) => {
181                self.hook_view_path = Some(Arc::new(f));
182            }
183            None => {
184                self.hook_view_path = None;
185            }
186        }
187    }
188
189    pub fn hook_view_path(&self) -> &HookViewPathHandlerType {
190        &self.hook_view_path
191    }
192
193    pub async fn serve(&self, extend_router: Option<Router>, catcher: Option<Catcher>) {
194        let pub_assets_router = Router::with_path(format!("{}/<**>", self.pub_assets_dir_name))
195            .get(
196                StaticDir::new([&self.pub_assets_dir_name])
197                    .defaults(match &self.default_asset_filename {
198                        Some(v) => {
199                            vec![v.to_owned()]
200                        }
201                        None => {
202                            vec![]
203                        }
204                    })
205                    .auto_list(self.listing_assets),
206            );
207        let view_router = Router::with_path("/<**rest_path>").get(ViewHandler::<ErrorWriter>::new(
208            self.gen_tera_builder(),
209            self.default_view_file_postfix.clone(),
210            self.default_view_file_name.clone(),
211            self.hook_view_path.clone(),
212        ));
213        //let router = Router::new();
214
215        let router = match extend_router {
216            Some(r) => r,
217            None => Router::new(),
218        };
219        let router = router.push(pub_assets_router);
220        let router = router.push(view_router);
221        #[cfg(feature = "http3")]
222        enum VariantAcceptor<U> {
223            NonHttp3(TcpAcceptor),
224            Http3(U),
225        }
226
227        #[cfg(feature = "http3")]
228        let var_acceptor = match self.use_http3.as_ref() {
229            Some(cert) => {
230                let cert_bytes = tokio::fs::read(&cert.cert).await.unwrap();
231                let key_bytes = tokio::fs::read(&cert.key).await.unwrap();
232                let config = RustlsConfig::new(
233                    Keycert::new()
234                        .cert(cert_bytes.as_slice())
235                        .key(key_bytes.as_slice()),
236                );
237                let listener = TcpListener::new(self.host.clone()).rustls(config.clone());
238                let acceptor = QuinnListener::new(config, self.host.clone())
239                    .join(listener)
240                    .bind()
241                    .await;
242                VariantAcceptor::Http3(acceptor)
243            }
244            None => {
245                let acceptor = TcpListener::new(&self.host).bind().await;
246                VariantAcceptor::NonHttp3(acceptor)
247            }
248        };
249
250        match catcher {
251            Some(catcher) => {
252                let service = Service::new(router).catcher(catcher);
253                #[cfg(feature = "http3")]
254                {
255                    match var_acceptor {
256                        VariantAcceptor::Http3(acceptor) => {
257                            Server::new(acceptor).serve(service).await;
258                        }
259                        VariantAcceptor::NonHttp3(acceptor) => {
260                            Server::new(acceptor).serve(service).await;
261                        }
262                    }
263                }
264                #[cfg(not(feature = "http3"))]
265                {
266                    let acceptor = TcpListener::new(&self.host).bind().await;
267                    Server::new(acceptor).serve(service).await;
268                }
269            }
270            None => {
271                #[cfg(feature = "http3")]
272                {
273                    match var_acceptor {
274                        VariantAcceptor::Http3(acceptor) => {
275                            Server::new(acceptor).serve(router).await;
276                        }
277                        VariantAcceptor::NonHttp3(acceptor) => {
278                            Server::new(acceptor).serve(router).await;
279                        }
280                    }
281                }
282                #[cfg(not(feature = "http3"))]
283                {
284                    let acceptor = TcpListener::new(&self.host).bind().await;
285                    Server::new(acceptor).serve(router).await;
286                }
287            }
288        };
289    }
290}
291
292pub struct TeraBuilder {
293    tpl_dir: String,
294    tpl_funcs: TeraFunctionMap,
295    tpl_filters: TeraFilterMap,
296    ctx_generator: MetaInfoCollector,
297}
298impl TeraBuilder {
299    pub fn new(
300        tpl_dir: String,
301        tpl_funcs: TeraFunctionMap,
302        tpl_filters: TeraFilterMap,
303        ctx_generator: MetaInfoCollector,
304    ) -> Self {
305        Self {
306            tpl_dir,
307            tpl_funcs,
308            tpl_filters,
309            ctx_generator,
310        }
311    }
312
313    fn register_utilities(&self, tera: &mut Tera) {
314        for (k, v) in &self.tpl_funcs {
315            tera.register_function(k, CallableObjectForTera(Arc::clone(v)));
316        }
317        for (k, v) in &self.tpl_filters {
318            tera.register_filter(k, CallableObjectForTera(Arc::clone(v)));
319        }
320    }
321
322    pub fn build(&self, ctx: Context) -> tera::Result<(Tera, Context)> {
323        let mut tera = Tera::new(&self.tpl_dir)?;
324        self.register_utilities(&mut tera);
325        tera.register_filter(
326            "json_decode",
327            |v: &Value, _args: &HashMap<String, Value>| -> tera::Result<Value> {
328                let v = v
329                    .as_str()
330                    .ok_or(tera::Error::msg("value must be a json object string"))?;
331                let v = serde_json::from_str::<Value>(v)?;
332                Ok(v)
333            },
334        );
335        tera.register_function("include_file", generate_include(tera.clone(), ctx.clone()));
336        Ok((tera, ctx))
337    }
338
339    pub fn gen_context(&self, req: &Request) -> Context {
340        match self.ctx_generator {
341            Some(ref collect) => {
342                let mut context = Context::new();
343                for (k, val) in collect(req) {
344                    context.insert(k, &val);
345                }
346                context
347            }
348            None => Context::default(),
349        }
350    }
351}
352
353struct ViewHandler<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> = anyhow::Error> {
354    tera_builder: TeraBuilder,
355    phantom_data_: PhantomData<ErrorWriter>,
356    default_postfix: String,
357    default_view_file_name: String,
358    hook_view_path: HookViewPathHandlerType,
359}
360impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error>> ViewHandler<ErrorWriter> {
361    fn new(
362        tera_builder: TeraBuilder,
363        default_postfix: String,
364        default_view_file_name: String,
365        hook_view_path: HookViewPathHandlerType,
366    ) -> Self {
367        Self {
368            tera_builder,
369            phantom_data_: PhantomData,
370            default_postfix,
371            default_view_file_name,
372            hook_view_path,
373        }
374    }
375}
376#[handler]
377impl<ErrorWriter: Writer + From<anyhow::Error> + From<tera::Error> + Send + Sync + 'static>
378    ViewHandler<ErrorWriter>
379{
380    async fn handle(
381        &self,
382        req: &mut Request,
383        _depot: &mut Depot,
384        res: &mut Response,
385    ) -> Result<(), ErrorWriter> {
386        let Some(path) = req.param::<String>("**rest_path") else {
387            res.status_code(StatusCode::BAD_REQUEST);
388            return Err(anyhow::format_err!("invalid request path").into());
389        };
390        let ctx = self.tera_builder.gen_context(req);
391        let path = if path.is_empty() {
392            self.default_view_file_name.to_string()
393        } else {
394            match path.rfind('.') {
395                Some(_) => path,
396                None => {
397                    format!("{path}.{}", self.default_postfix)
398                }
399            }
400        };
401        let path = match &self.hook_view_path {
402            Some(f) => f(&mut *req, path),
403            None => path,
404        };
405        if !cfg!(debug_assertions) {
406            let (tera, ctx) = self.tera_builder.build(ctx.clone())?;
407            match tera.render(&path, &ctx) {
408                Ok(html) => {
409                    res.render(Text::Html(html));
410                }
411                Err(e) => {
412                    if let tera::ErrorKind::TemplateNotFound(_) = &e.kind {
413                        res.status_code(StatusCode::NOT_FOUND);
414                    } else {
415                        res.status_code(StatusCode::BAD_REQUEST);
416                    }
417                    return Err(anyhow::format_err!("{}", e.to_string()).into());
418                }
419            };
420        } else {
421            match self.tera_builder.build(ctx.clone()) {
422                Ok((tera, ctx)) => match tera.render(&path, &ctx) {
423                    Ok(s) => {
424                        res.render(Text::Html(s));
425                    }
426                    Err(e) => {
427                        if let tera::ErrorKind::TemplateNotFound(_) = &e.kind {
428                            res.status_code(StatusCode::NOT_FOUND);
429                        } else {
430                            res.status_code(StatusCode::BAD_REQUEST);
431                        }
432                        return Err(anyhow::format_err!("{e:?}").into());
433                    }
434                },
435                Err(e) => {
436                    res.status_code(StatusCode::BAD_REQUEST);
437                    return Err(anyhow::format_err!("{e:?}").into());
438                }
439            };
440        }
441        Ok(())
442    }
443}
444
445fn generate_include(tera: Tera, parent: Context) -> impl Function {
446    move |args: &HashMap<String, Value>| -> tera::Result<Value> {
447        let Some(file_path) = args.get("path") else {
448            return Err(tera::Error::msg("file does not exist in the template path"));
449        };
450        match args.get("context") {
451            Some(v) => {
452                //println!("value === {v}");
453                let context_value = v
454                    .as_str()
455                    .ok_or(tera::Error::msg("context must be a json object string"))?;
456                let v = serde_json::from_str::<Value>(context_value)?;
457                let mut context = Context::from_value(serde_json::json!({ "context": v }))?;
458                let mut tera = tera.clone();
459                context.insert("__Parent", &parent.clone().into_json());
460                tera.register_function(
461                    "include_file",
462                    generate_include(tera.clone(), context.clone()),
463                );
464                let r = tera
465                    .render(
466                        file_path
467                            .as_str()
468                            .ok_or(tera::Error::msg("template render error"))?,
469                        &context,
470                    )?
471                    .to_string();
472                Ok(Value::String(r))
473            }
474            None => {
475                let mut context =
476                    Context::from_value(serde_json::json!({ "context": Value::Null }))?;
477                let mut tera = tera.clone();
478                context.insert("__Parent", &parent.clone().into_json());
479                tera.register_function(
480                    "include_file",
481                    generate_include(tera.clone(), context.clone()),
482                );
483                let r = tera
484                    .render(
485                        file_path
486                            .as_str()
487                            .ok_or(tera::Error::msg("template render error"))?,
488                        &context,
489                    )?
490                    .to_string();
491                return Ok(Value::String(r));
492            }
493        }
494    }
495}
496
/// Runs an `SSRender` on a freshly built multi-threaded tokio runtime,
/// blocking the current thread until `serve` returns.
///
/// Forms:
/// - `ssr_work!(render)` — no extra router, no catcher;
/// - `ssr_work!(render, router)` — extra router, no catcher;
/// - `ssr_work!(render, router, catcher)` — extra router and catcher;
/// - `ssr_work!(render, None, catcher)` — catcher only.
///
/// The `@block_on` arm is an implementation detail shared by the public forms.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be built.
#[macro_export]
macro_rules! ssr_work {
    (@block_on $serve:expr) => {
        $crate::tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime")
            .block_on(async {
                $serve.await;
            });
    };
    ($e:expr, None, $catcher:expr) => {
        $crate::ssr_work!(@block_on $e.serve(None, Some($catcher)));
    };
    ($e:expr, $router:expr, $catcher:expr) => {
        $crate::ssr_work!(@block_on $e.serve(Some($router), Some($catcher)));
    };
    ($e:expr, $router:expr) => {
        $crate::ssr_work!(@block_on $e.serve(Some($router), None));
    };
    ($e:expr) => {
        $crate::ssr_work!(@block_on $e.serve(None, None));
    };
}