// spa_rs/lib.rs

//! spa-rs is a library that embeds all files of a SPA web application (its `dist` static files)
//! so the whole application can be released as a single binary executable.
//!
//! It is based on [axum] and [rust_embed].
//!
//! It re-exports all axum modules for convenient use.
//! # Example
//! ```no_run
//! use spa_rs::spa_server_root;
//! use spa_rs::SpaServer;
//! use spa_rs::routing::{get, Router};
//! use anyhow::Result;
//!
//! spa_server_root!("web/dist");           // specify your SPA dist file location
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//!     let data = String::new();           // server context, accessible through [axum::Extension]
//!     let srv = SpaServer::new()?
//!         .port(3000)
//!         .data(data)
//!         .static_path("/png", "web")     // static files generated at runtime
//!         .route("/api", Router::new()
//!             .route("/get", get(|| async { "get works" })
//!         )
//!     );
//!     srv.run(spa_server_root!()).await?;
//!
//!     Ok(())
//! }
//! ```
//!
//! # Session
//! See the [session] module for more details.
//!
//! # Dev
//! When writing a SPA application, you may want to use the hot-reload functionality provided
//! by the SPA framework, such as `vite dev` or `ng serve`.
//!
//! You can use spa-rs to reverse proxy all static requests to the SPA framework's dev server
//! (requires the `reverse-proxy` feature).
//!
//! ## Example
//! ```ignore
//!   let forward_addr = "localhost:1234";
//!   let srv = srv.reverse_proxy(forward_addr);
//! ```
47use anyhow::{anyhow, Context, Result};
48use axum::{
49    body::{Bytes, HttpBody},
50    extract::Request,
51    http::HeaderValue,
52    response::Response,
53    routing::{get_service, Route},
54};
55#[cfg(feature = "openssl")]
56use axum_server::tls_openssl::OpenSSLConfig;
57#[cfg(feature = "rustls")]
58use axum_server::tls_rustls::RustlsConfig;
59use futures_util::future::BoxFuture;
60use http::{
61    header::{self},
62    StatusCode,
63};
64#[cfg(feature = "reverse-proxy")]
65use http::{Method, Uri};
66use log::{debug, error, warn};
67pub use rust_embed;
68use std::{
69    collections::HashMap,
70    convert::Infallible,
71    env::current_exe,
72    fs::{self, create_dir_all},
73    net::SocketAddr,
74    path::PathBuf,
75    sync::Arc,
76};
77use tower::{util::ServiceExt as TowerServiceExt, Layer, Service};
78use tower_http::{
79    services::{ServeDir, ServeFile},
80    set_header::SetResponseHeaderLayer,
81};
82
83pub use axum::*;
84pub mod auth;
85pub mod session;
86pub use axum::debug_handler;
87pub use axum_help::*;
88
89/// A server wrapped axum server.
90///
91/// It can:
92/// - serve static files in SPA root path
93/// - serve API requests in router
94/// - fallback to SPA static file when route matching failed
95///     - if still get 404, it will redirect to SPA index.html
96///
97#[derive(Default)]
98pub struct SpaServer<T = ()>
99where
100    T: Clone + Send + Sync + 'static,
101{
102    static_path: Vec<(String, PathBuf)>,
103    port: u16,
104    router: Router,
105    data: Option<T>,
106    forward: Option<String>,
107    release_path: PathBuf,
108    extra_layer: Vec<Box<dyn FnOnce(Router) -> Router>>,
109    host_routers: HashMap<String, Router>,
110}
111
112#[axum::debug_handler]
113#[cfg(feature = "reverse-proxy")]
114async fn forwarded_to_dev(
115    Extension(forward_addr): Extension<String>,
116    uri: Uri,
117    method: Method,
118) -> HttpResult<Response> {
119    use axum::http::Response;
120    use http::uri::Scheme;
121
122    if method == Method::GET {
123        let client = reqwest::Client::builder().no_proxy().build()?;
124        let mut parts = uri.into_parts();
125        parts.authority = Some(forward_addr.parse()?);
126        if parts.scheme.is_none() {
127            parts.scheme = Some(Scheme::HTTP);
128        }
129        let url = Uri::from_parts(parts)?.to_string();
130
131        println!("forward url: {}", url);
132        let response = client.get(url).send().await?;
133        let response: http::Response<_> = response.into();
134        let (parts, body) = response.into_parts();
135        let body = body.as_bytes().map(|b| b.to_vec()).unwrap_or_default();
136
137        let response = Response::from_parts(parts, body.into());
138        return Ok(response);
139    }
140
141    Err(HttpError {
142        message: "Method not allowed".to_string(),
143        status_code: StatusCode::METHOD_NOT_ALLOWED,
144    })
145}
146
147#[cfg(not(feature = "reverse-proxy"))]
148async fn forwarded_to_dev() {
149    unreachable!("reverse-proxy not enabled, should never call forwarded_to_dev")
150}
151
152impl<T> SpaServer<T>
153where
154    T: Clone + Send + Sync + 'static,
155{
156    /// Just new(), nothing special
157    pub fn new() -> Result<Self> {
158        Ok(Self {
159            static_path: Vec::new(),
160            port: 8080,
161            forward: None,
162            release_path: current_exe()?
163                .parent()
164                .ok_or_else(|| anyhow!("no parent in current_exe"))?
165                .join(format!(".{}_static_files", env!("CARGO_PKG_NAME"))),
166            extra_layer: Vec::new(),
167            host_routers: HashMap::new(),
168            router: Router::new(),
169            data: None,
170        })
171    }
172
173    /// Specific server context data
174    ///
175    /// This is similar to [axum middleware](https://docs.rs/axum/latest/axum/#middleware)
176    pub fn data(mut self, data: T) -> Self
177    where
178        T: Clone + Send + Sync + 'static,
179    {
180        self.data = Some(data);
181        self
182    }
183
184    /// Specific an axum layer to server
185    ///
186    /// This is similar to [axum middleware](https://docs.rs/axum/latest/axum/#middleware)
187    pub fn layer<L, NewResBody>(mut self, layer: L) -> Self
188    where
189        L: Layer<Route> + Clone + Send + 'static,
190        L::Service: Service<Request, Response = Response<NewResBody>, Error = Infallible>
191            + Clone
192            + Send
193            + 'static,
194        <L::Service as Service<Request>>::Future: Send + 'static,
195        NewResBody: HttpBody<Data = Bytes> + Send + 'static,
196        NewResBody::Error: Into<BoxError>,
197    {
198        self.extra_layer.push(Box::new(move |app| app.layer(layer)));
199        self
200    }
201
202    /// make a reverse proxy which redirect all SPA requests to dev server, such as `ng serve`, `vite`.  
203    ///
204    /// it's useful when debugging UI
205    #[cfg(feature = "reverse-proxy")]
206    #[cfg_attr(docsrs, doc(cfg(feature = "reverse-proxy")))]
207    pub fn reverse_proxy(mut self, addr: impl Into<String>) -> Self {
208        self.forward = Some(addr.into());
209        self
210    }
211
212    /// static file release path in runtime
213    ///
214    /// Default path is /tmp/[env!(CARGO_PKG_NAME)]_static_files
215    pub fn release_path(mut self, rp: impl Into<PathBuf>) -> Self {
216        self.release_path = rp.into();
217        self
218    }
219
220    /// Run the spa server forever
221    pub async fn run<Root>(self, root: Root) -> Result<()>
222    where
223        Root: SpaStatic,
224    {
225        self.run_raw(Some(root), None).await
226    }
227
228    /// Run the spa server with tls
229    #[cfg(any(feature = "openssl", feature = "rustls"))]
230    pub async fn run_tls<Root>(self, root: Root, config: HttpsConfig) -> Result<()>
231    where
232        Root: SpaStatic,
233    {
234        self.run_raw(Some(root), Some(config)).await
235    }
236
237    /// Run the spa server without spa root
238    pub async fn run_api(self) -> Result<()> {
239        self.run_raw::<ApiOnly>(None, None).await
240    }
241
242    /// Run the spa server with tls and without spa root
243    #[cfg(any(feature = "openssl", feature = "rustls"))]
244    pub async fn run_api_tls(self, config: HttpsConfig) -> Result<()> {
245        self.run_raw::<ApiOnly>(None, Some(config)).await
246    }
247
248    /// Run the spa server with or without spa root, and with or without tls
249    async fn run_raw<Root>(mut self, root: Option<Root>, config: Option<HttpsConfig>) -> Result<()>
250    where
251        Root: SpaStatic,
252    {
253        if let Some(root) = root {
254            let embeded_dir = root.release(self.release_path)?;
255            let index_file = embeded_dir.clone().join("index.html");
256
257            self.router = if let Some(addr) = self.forward {
258                self.router
259                    .fallback(forwarded_to_dev)
260                    .layer(Extension(addr))
261            } else {
262                self.router.fallback_service(
263                    get_service(ServeDir::new(&embeded_dir).fallback(ServeFile::new(index_file)))
264                        .layer(Self::add_cache_control())
265                        .handle_error(|e: anyhow::Error| async move {
266                            (
267                                StatusCode::INTERNAL_SERVER_ERROR,
268                                format!(
269                            "Unhandled internal server error {:?} when serve embeded path {}",
270                            e,
271                            embeded_dir.display()
272                        ),
273                            )
274                        }),
275                )
276            };
277        }
278
279        for sf in self.static_path {
280            self.router = self.router.nest_service(
281                &sf.0,
282                get_service(ServeDir::new(&sf.1))
283                    .layer(Self::add_cache_control())
284                    .handle_error(|e: anyhow::Error| async move {
285                        (
286                            StatusCode::INTERNAL_SERVER_ERROR,
287                            format!(
288                                "Unhandled internal server error {:?} when serve static path {}",
289                                e,
290                                sf.1.display()
291                            ),
292                        )
293                    }),
294            )
295        }
296
297        self.router = self
298            .router
299            .layer(MatchHostLayer::new(Arc::new(self.host_routers.clone())));
300
301        if let Some(data) = self.data {
302            self.router = self.router.layer(Extension(data));
303        }
304
305        for layer in self.extra_layer {
306            self.router = layer(self.router)
307        }
308
309        let addr = format!("0.0.0.0:{}", self.port).parse()?;
310        #[allow(unused_variables)]
311        if let Some(config) = config {
312            #[cfg(all(feature = "openssl", feature = "rustls"))]
313            compile_error!("Feature openssl and Feature rustls can not be enabled together");
314
315            #[cfg(any(feature = "openssl", feature = "rustls"))]
316            {
317                #[cfg(feature = "rustls")]
318                {
319                    let certificate = std::fs::read(config.certificate)?;
320                    let private_key = std::fs::read(config.private_key)?;
321                    axum_server::bind_rustls(
322                        addr,
323                        RustlsConfig::from_pem(certificate, private_key).await?,
324                    )
325                }
326                #[cfg(feature = "openssl")]
327                {
328                    axum_server::bind_openssl(
329                        addr,
330                        OpenSSLConfig::from_pem_file(config.certificate, config.private_key)
331                            .context("openssl load pem file error")?,
332                    )
333                }
334            }
335            .serve(
336                self.router
337                    .into_make_service_with_connect_info::<SocketAddr>(),
338            )
339            .await?;
340        } else {
341            axum_server::bind(addr)
342                .serve(
343                    self.router
344                        .into_make_service_with_connect_info::<SocketAddr>(),
345                )
346                .await
347                .context("serve server error")?;
348        }
349
350        Ok(())
351    }
352
353    /// Setting up server router, see example for usage.
354    ///
355    pub fn route(mut self, path: impl AsRef<str>, router: Router) -> Self {
356        self.router = self.router.nest(path.as_ref(), router);
357        self
358    }
359
360    /// Server listening port, default is 8080
361    ///
362    pub fn port(mut self, port: u16) -> Self {
363        self.port = port;
364        self
365    }
366
367    /// Setting up a runtime static file path.
368    ///
369    /// Unlike [spa_server_root], file in this path can be changed in runtime.
370    pub fn static_path(mut self, path: impl Into<String>, dir: impl Into<PathBuf>) -> Self {
371        self.static_path.push((path.into(), dir.into()));
372        self
373    }
374
375    /// add host based router
376    ///
377    pub fn host_router(mut self, host: impl Into<String>, router: Router) -> Self {
378        self.host_routers.insert(host.into(), router);
379        self
380    }
381
382    fn add_cache_control() -> SetResponseHeaderLayer<HeaderValue> {
383        SetResponseHeaderLayer::if_not_present(
384            header::CACHE_CONTROL,
385            HeaderValue::from_static("max-age=300"),
386        )
387    }
388}
389
390#[derive(Clone)]
391struct MatchHostLayer {
392    host_routers: Arc<HashMap<String, Router>>,
393}
394
395impl<S> Layer<S> for MatchHostLayer {
396    type Service = MatchHost<S>;
397
398    fn layer(&self, inner: S) -> Self::Service {
399        MatchHost {
400            inner,
401            host_routers: self.host_routers.clone(),
402        }
403    }
404}
405
406impl MatchHostLayer {
407    pub fn new(host_routers: Arc<HashMap<String, Router>>) -> Self {
408        Self { host_routers }
409    }
410}
411
412#[derive(Clone)]
413struct MatchHost<S> {
414    inner: S,
415    host_routers: Arc<HashMap<String, Router>>,
416}
417
418impl<S> Service<Request> for MatchHost<S>
419where
420    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
421    S::Future: Send + 'static,
422    S::Error: Into<Infallible>,
423{
424    type Response = S::Response;
425    type Error = S::Error;
426    type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
427
428    fn poll_ready(
429        &mut self,
430        cx: &mut std::task::Context<'_>,
431    ) -> std::task::Poll<Result<(), Self::Error>> {
432        self.inner.poll_ready(cx)
433    }
434
435    fn call(&mut self, req: Request) -> Self::Future {
436        let host_routers = self.host_routers.clone();
437        let mut srv = self.inner.clone();
438        Box::pin(async move {
439            let hostname = req
440                .headers()
441                .get(header::HOST)
442                .and_then(|h| h.to_str().ok())
443                .unwrap_or_default();
444
445            if let Some((_, router)) = host_routers.iter().find(|(k, _v)| hostname.ends_with(*k)) {
446                router.clone().oneshot(req).await
447            } else {
448                srv.call(req).await
449            }
450        })
451    }
452}
453
454pub struct HttpsConfig {
455    pub certificate: PathBuf,
456    pub private_key: PathBuf,
457}
458
459/// setup https pems   
460///
461/// ## Example
462/// ```
463/// https_pems!("/some/folder/contains/two/pem/file");
464/// ```
465///
466/// ## Caution
467/// pem file name should be [`cert.pem`] and [`key.pem`]
468///
469#[macro_export]
470macro_rules! embed_https_pems {
471    ($path: literal) => {
472        #[derive(spa_rs::rust_embed::RustEmbed)]
473        #[crate_path = "spa_rs::rust_embed"]
474        #[folder = $path]
475        struct HttpsPems;
476    };
477
478    () => {{
479        let https_config = || -> anyhow::Result<spa_rs::HttpsConfig> {
480            let mut base_path = std::env::temp_dir().join(format!(
481                "{}_{}",
482                env!("CARGO_PKG_NAME"),
483                env!("CARGO_PKG_VERSION")
484            ));
485            let _ = std::fs::create_dir_all(&base_path);
486            let mut cert_path = None;
487            let mut key_path = None;
488            for file in HttpsPems::iter() {
489                if let Some(f) = HttpsPems::get(&file) {
490                    if file == "key.pem" {
491                        key_path = Some(base_path.join("key.pem"));
492                        std::fs::write(key_path.as_ref().unwrap(), &f.data)?;
493                    }
494
495                    if file == "cert.pem" {
496                        cert_path = Some(base_path.join("cert.pem"));
497                        std::fs::write(cert_path.as_ref().unwrap(), &f.data)?;
498                    }
499                }
500            }
501
502            if cert_path.is_none() || key_path.is_none() {
503                anyhow::bail!("invalid ssl cert or key embed file");
504            }
505
506            Ok(spa_rs::HttpsConfig {
507                certificate: cert_path.unwrap(),
508                private_key: key_path.unwrap(),
509            })
510        };
511        https_config()
512    }};
513}
514
515/// Specific SPA dist file root path in compile time
516///
517#[macro_export]
518macro_rules! spa_server_root {
519    ($root: literal) => {
520        #[derive(spa_rs::rust_embed::RustEmbed)]
521        #[crate_path = "spa_rs::rust_embed"]
522        #[folder = $root]
523        struct StaticFiles;
524
525        impl spa_rs::SpaStatic for StaticFiles {}
526    };
527    () => {
528        StaticFiles
529    };
530}
531
532/// Used to release static file into temp dir in runtime.
533///
534pub trait SpaStatic: rust_embed::RustEmbed {
535    fn release(&self, release_path: PathBuf) -> Result<PathBuf> {
536        let target_dir = release_path;
537        if !target_dir.exists() {
538            create_dir_all(&target_dir)?;
539        }
540
541        for file in Self::iter() {
542            match Self::get(&file) {
543                Some(f) => {
544                    if let Some(p) = std::path::Path::new(file.as_ref()).parent() {
545                        let parent_dir = target_dir.join(p);
546                        create_dir_all(parent_dir)?;
547                    }
548
549                    let path = target_dir.join(file.as_ref());
550                    debug!("release static file: {}", path.display());
551                    if let Err(e) = fs::write(path, f.data) {
552                        error!("static file {} write error: {:?}", file, e);
553                    }
554                }
555                None => warn!("static file {} not found", file),
556            }
557        }
558
559        Ok(target_dir)
560    }
561}
562
563impl SpaStatic for ApiOnly {}
564impl rust_embed::RustEmbed for ApiOnly {
565    fn get(_file_path: &str) -> Option<rust_embed::EmbeddedFile> {
566        unreachable!()
567    }
568
569    fn iter() -> rust_embed::Filenames {
570        unreachable!()
571    }
572}
573
574struct ApiOnly;