spa_rs/
lib.rs

//! spa-rs is a library that embeds all files of a SPA web application (its `dist`
//! static files) so the application can be released as a single executable.
//!
//! It is based on [axum] and [rust_embed].
//!
//! It re-exports all of axum's modules for convenient use.
7//! # Example
8//! ```no_run
9//! use spa_rs::spa_server_root;
10//! use spa_rs::SpaServer;
11//! use spa_rs::routing::{get, Router};
12//! use anyhow::Result;
13//!
//! spa_server_root!("web/dist");           // specify your SPA dist file location
15//!
16//! #[tokio::main]
17//! async fn main() -> Result<()> {
//!     let data = String::new();           // server context, accessible in handlers via `axum::Extension`
19//!     let mut srv = SpaServer::new()?
20//!         .port(3000)
21//!         .data(data)
//!         .static_path("/png", "web")     // files generated at runtime, served from disk
23//!         .route("/api", Router::new()
24//!             .route("/get", get(|| async { "get works" })
25//!         )
26//!     );
27//!     srv.run(spa_server_root!()).await?;  
28//!
29//!     Ok(())
30//! }
31//! ```
32//!
33//! # Session
//! See the [session] module for more details.
35//!
36//! # Dev
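//! When writing a SPA application, you may want to use the hot-reload functionality
//! provided by its framework, such as `vite dev` or `ng serve`.
//!
//! spa-rs can reverse-proxy every request that no API route matches to that framework's
//! dev server (requires the `reverse-proxy` feature). Note: the feature is currently
//! disabled; enabling it triggers a compile error until reqwest supports hyper 1.0.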
41//!
42//! ## Example
43//! ```ignore
//!   let forward_addr = "http://localhost:1234";
//!   let srv = srv.reverse_proxy(forward_addr);
46//! ```
47use anyhow::{anyhow, Context, Result};
48use axum::{
49    body::Bytes,
50    body::HttpBody,
51    extract::{Host, Request},
52    http::HeaderValue,
53    response::Response,
54    routing::{any, get_service, Route},
55};
56#[cfg(feature = "openssl")]
57use axum_server::tls_openssl::OpenSSLConfig;
58#[cfg(feature = "rustls")]
59use axum_server::tls_rustls::RustlsConfig;
60use http::{
61    header::{self},
62    StatusCode,
63};
64#[cfg(feature = "reverse-proxy")]
65use http::{Method, Uri};
66use log::{debug, error, warn};
67use std::{
68    collections::HashMap,
69    convert::Infallible,
70    env::current_exe,
71    fs::{self, create_dir_all},
72    net::SocketAddr,
73    path::{Path, PathBuf},
74};
75use tower::{Layer, Service, ServiceExt as TowerServiceExt};
76use tower_http::{
77    services::{ServeDir, ServeFile},
78    set_header::SetResponseHeaderLayer,
79};
80
81pub mod rust_embed {
82    pub use rust_embed::*;
83}
84
85pub use axum::*;
86pub mod auth;
87pub mod session;
88pub use axum::debug_handler;
89pub use axum_help::*;
90
91/// A server wrapped axum server.
92///
93/// It can:
94/// - serve static files in SPA root path
95/// - serve API requests in router
96/// - fallback to SPA static file when route matching failed
97///     - if still get 404, it will redirect to SPA index.html
98///
99#[derive(Default)]
100pub struct SpaServer<T = ()>
101where
102    T: Clone + Send + Sync + 'static,
103{
104    static_path: Vec<(String, PathBuf)>,
105    port: u16,
106    main_router: Router,
107    api_router: Router,
108    data: Option<T>,
109    forward: Option<String>,
110    release_path: PathBuf,
111    extra_layer: Vec<Box<dyn FnOnce(Router) -> Router>>,
112    host_routers: HashMap<String, Router>,
113}
114
/// Fallback handler meant to forward non-API requests to a frontend dev server
/// (feature `reverse-proxy`).
///
/// The target address is read from the `Extension<String>` that `run_raw` layers onto the
/// API router when `SpaServer::reverse_proxy` was configured.
///
/// NOTE: currently disabled. The `compile_error!` below makes every build with the
/// `reverse-proxy` feature fail. The commented-out code is the intended implementation,
/// parked until reqwest moves to hyper 1.0.
#[cfg(feature = "reverse-proxy")]
async fn forwarded_to_dev(
    Extension(forward_addr): Extension<String>,
    uri: Uri,
    method: Method,
) -> HttpResult<Response> {
    compile_error!("Can not use now, wait for reqwest upgrade hyper to 1.0");
    // Parked implementation: only GET is proxied. The request URI's authority is replaced
    // by `forward_addr`, and the scheme defaults to http. Any other method gets
    // 405 Method Not Allowed.
    // use http::uri::Scheme;

    // if method == Method::GET {
    //     let client = reqwest::Client::builder().no_proxy().build()?;
    //     let mut parts = uri.into_parts();
    //     parts.authority = Some(forward_addr.parse()?);
    //     if parts.scheme.is_none() {
    //         parts.scheme = Some(Scheme::HTTP);
    //     }
    //     let url = Uri::from_parts(parts)?.to_string();

    //     println!("forward url: {}", url);
    //     let response = client.get(url).send().await?;
    //     let status = response.status();
    //     let headers = response.headers().clone();
    //     let bytes = response.bytes().await?;

    //     let mut response = Response::builder().status(status);
    //     *(response.headers_mut().unwrap()) = headers;
    //     let response = response.body(bytes)?;
    //     return Ok(response);
    // }

    // Err(HttpError {
    //     message: "Method not allowed".to_string(),
    //     status_code: StatusCode::METHOD_NOT_ALLOWED,
    // })
    todo!()
}
151
/// Placeholder that lets `run_raw` type-check when the `reverse-proxy` feature is off.
///
/// Without that feature nothing can set a forward address, so this handler is never
/// registered. Reaching it means there is a bug.
#[cfg(not(feature = "reverse-proxy"))]
async fn forwarded_to_dev() {
    let reason = "reverse-proxy not enabled, should never call forwarded_to_dev";
    unreachable!("{}", reason)
}
156
157impl<T> SpaServer<T>
158where
159    T: Clone + Send + Sync + 'static,
160{
161    /// Just new(), nothing special
162    pub fn new() -> Result<Self> {
163        Ok(Self {
164            static_path: Vec::new(),
165            port: 8080,
166            main_router: Router::new(),
167            forward: None,
168            release_path: current_exe()?
169                .parent()
170                .ok_or_else(|| anyhow!("no parent in current_exe"))?
171                .join(format!(".{}_static_files", env!("CARGO_PKG_NAME"))),
172            extra_layer: Vec::new(),
173            host_routers: HashMap::new(),
174            api_router: Router::new(),
175            data: None,
176        })
177    }
178
179    /// Specific server context data
180    ///
181    /// This is similar to [axum middleware](https://docs.rs/axum/latest/axum/#middleware)
182    pub fn data(mut self, data: T) -> Self
183    where
184        T: Clone + Send + Sync + 'static,
185    {
186        self.data = Some(data);
187        self
188    }
189
190    /// Specific an axum layer to server
191    ///
192    /// This is similar to [axum middleware](https://docs.rs/axum/latest/axum/#middleware)
193    pub fn layer<L, NewResBody>(mut self, layer: L) -> Self
194    where
195        L: Layer<Route> + Clone + Send + 'static,
196        L::Service: Service<Request, Response = Response<NewResBody>, Error = Infallible>
197            + Clone
198            + Send
199            + 'static,
200        <L::Service as Service<Request>>::Future: Send + 'static,
201        NewResBody: HttpBody<Data = Bytes> + Send + 'static,
202        NewResBody::Error: Into<BoxError>,
203    {
204        self.extra_layer.push(Box::new(move |app| app.layer(layer)));
205        self
206    }
207
208    /// make a reverse proxy which redirect all SPA requests to dev server, such as `ng serve`, `vite`.  
209    ///
210    /// it's useful when debugging UI
211    #[cfg(feature = "reverse-proxy")]
212    #[cfg_attr(docsrs, doc(cfg(feature = "reverse-proxy")))]
213    pub fn reverse_proxy(mut self, addr: impl Into<String>) -> Self {
214        self.forward = Some(addr.into());
215        self
216    }
217
218    /// static file release path in runtime
219    ///
220    /// Default path is /tmp/[env!(CARGO_PKG_NAME)]_static_files
221    pub fn release_path(mut self, rp: impl Into<PathBuf>) -> Self {
222        self.release_path = rp.into();
223        self
224    }
225
226    /// Run the spa server forever
227    pub async fn run<Root>(self, root: Root) -> Result<()>
228    where
229        Root: SpaStatic,
230    {
231        self.run_raw(Some(root), None).await
232    }
233
234    /// Run the spa server with tls
235    #[cfg(any(feature = "openssl", feature = "rustls"))]
236    pub async fn run_tls<Root>(self, root: Root, config: HttpsConfig) -> Result<()>
237    where
238        Root: SpaStatic,
239    {
240        self.run_raw(Some(root), Some(config)).await
241    }
242
243    /// Run the spa server without spa root
244    pub async fn run_api(self) -> Result<()> {
245        self.run_raw::<ApiOnly>(None, None).await
246    }
247
248    /// Run the spa server with tls and without spa root
249    #[cfg(any(feature = "openssl", feature = "rustls"))]
250    pub async fn run_api_tls(self, config: HttpsConfig) -> Result<()> {
251        self.run_raw::<ApiOnly>(None, Some(config)).await
252    }
253
254    /// Run the spa server with or without spa root, and with or without tls
255    async fn run_raw<Root>(mut self, root: Option<Root>, config: Option<HttpsConfig>) -> Result<()>
256    where
257        Root: SpaStatic,
258    {
259        if let Some(root) = root {
260            let embeded_dir = root.release(self.release_path)?;
261            let index_file = embeded_dir.clone().join("index.html");
262
263            self.api_router = if let Some(addr) = self.forward {
264                self.api_router
265                    .fallback(forwarded_to_dev)
266                    .layer(Extension(addr))
267            } else {
268                self.api_router.fallback_service(
269                    get_service(ServeDir::new(&embeded_dir).fallback(ServeFile::new(index_file)))
270                        .layer(Self::add_cache_control())
271                        .handle_error(|e: anyhow::Error| async move {
272                            (
273                                StatusCode::INTERNAL_SERVER_ERROR,
274                                format!(
275                            "Unhandled internal server error {:?} when serve embeded path {}",
276                            e,
277                            embeded_dir.display()
278                        ),
279                            )
280                        }),
281                )
282            };
283        }
284
285        for sf in self.static_path {
286            self.api_router = self.api_router.nest_service(
287                &sf.0,
288                get_service(ServeDir::new(&sf.1))
289                    .layer(Self::add_cache_control())
290                    .handle_error(|e: anyhow::Error| async move {
291                        (
292                            StatusCode::INTERNAL_SERVER_ERROR,
293                            format!(
294                                "Unhandled internal server error {:?} when serve static path {}",
295                                e,
296                                sf.1.display()
297                            ),
298                        )
299                    }),
300            )
301        }
302
303        let main_handler = |Host(hostname): Host, request: Request| async move {
304            if let Some(router) = self.host_routers.remove(&hostname) {
305                router.oneshot(request).await
306            } else {
307                self.api_router.oneshot(request).await
308            }
309        };
310        self.main_router = Router::new()
311            .route("/", any(main_handler.clone()))
312            .route("/*path", any(main_handler));
313
314        if let Some(data) = self.data {
315            self.main_router = self.main_router.layer(Extension(data));
316        }
317
318        for layer in self.extra_layer {
319            self.main_router = layer(self.main_router)
320        }
321
322        let addr = format!("0.0.0.0:{}", self.port).parse()?;
323        if let Some(_config) = config {
324            #[cfg(all(feature = "openssl", feature = "rustls"))]
325            compile_error!("Feature openssl and Feature rustls can not be enabled together");
326
327            #[cfg(any(feature = "openssl", feature = "rustls"))]
328            {
329                #[cfg(feature = "rustls")]
330                {
331                    axum_server::bind_rustls(
332                        addr,
333                        RustlsConfig::from_pem(_config.certificate, _config.private_key).await?,
334                    )
335                }
336                #[cfg(feature = "openssl")]
337                {
338                    let temp_dir = std::env::temp_dir().join(env!("CARGO_PKG_NAME"));
339                    std::fs::create_dir_all(&temp_dir)?;
340                    let cert_file = temp_dir.join("cert.pem");
341                    let key_file = temp_dir.join("key.pem");
342                    std::fs::write(&cert_file, &_config.certificate)?;
343                    std::fs::write(&key_file, &_config.private_key)?;
344                    axum_server::bind_openssl(
345                        addr,
346                        OpenSSLConfig::from_pem_file(cert_file, key_file)
347                            .context("openssl load pem file error")?,
348                    )
349                }
350            }
351            .serve(
352                self.main_router
353                    .into_make_service_with_connect_info::<SocketAddr>(),
354            )
355            .await?;
356        } else {
357            axum_server::bind(addr)
358                .serve(
359                    self.main_router
360                        .into_make_service_with_connect_info::<SocketAddr>(),
361                )
362                .await
363                .context("serve server error")?;
364        }
365
366        Ok(())
367    }
368
369    /// Setting up server router, see example for usage.
370    ///
371    pub fn route(mut self, path: impl AsRef<str>, router: Router) -> Self {
372        self.api_router = self.api_router.nest(path.as_ref(), router);
373        self
374    }
375
376    /// Server listening port, default is 8080
377    ///
378    pub fn port(mut self, port: u16) -> Self {
379        self.port = port;
380        self
381    }
382
383    /// Setting up a runtime static file path.
384    ///
385    /// Unlike [spa_server_root], file in this path can be changed in runtime.
386    pub fn static_path(mut self, path: impl Into<String>, dir: impl Into<PathBuf>) -> Self {
387        self.static_path.push((path.into(), dir.into()));
388        self
389    }
390
391    /// add host based router
392    ///
393    pub fn host_router(mut self, host: impl Into<String>, router: Router) -> Self {
394        self.host_routers.insert(host.into(), router);
395        self
396    }
397
398    fn add_cache_control() -> SetResponseHeaderLayer<HeaderValue> {
399        SetResponseHeaderLayer::if_not_present(
400            header::CACHE_CONTROL,
401            HeaderValue::from_static("max-age=300"),
402        )
403    }
404}
405
/// PEM-encoded TLS certificate and private key, consumed by `run_tls` / `run_api_tls`.
///
/// The `Debug` output shows only the certificate length and never the private key, so
/// the config can be logged without leaking the secret.
#[derive(Clone)]
pub struct HttpsConfig {
    /// PEM-encoded certificate.
    pub certificate: Vec<u8>,
    /// PEM-encoded private key.
    pub private_key: Vec<u8>,
}

impl std::fmt::Debug for HttpsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpsConfig")
            .field("certificate", &format_args!("<{} bytes>", self.certificate.len()))
            .field("private_key", &format_args!("<redacted>"))
            .finish()
    }
}
410
/// Embeds a TLS certificate and private key into the binary and loads them as an
/// [`HttpsConfig`].
///
/// The macro has two forms:
/// - `https_pems!("dir")` (item position) declares a `HttpsPems` type that embeds every
///   file of `dir` at compile time. Invoke it once per module.
/// - `https_pems!()` evaluates to an `anyhow::Result<HttpsConfig>` built from the embedded
///   `cert.pem` and `key.pem`. Use it where the `HttpsPems` type from the first form is in
///   scope. The calling crate must depend on `anyhow`.
///
/// ## Example
/// ```ignore
/// spa_rs::https_pems!("/some/folder/contains/two/pem/file");
///
/// async fn serve(srv: spa_rs::SpaServer) -> anyhow::Result<()> {
///     let config = spa_rs::https_pems!()?;
///     srv.run_api_tls(config).await
/// }
/// ```
///
/// ## Caution
/// The PEM files must be named `cert.pem` and `key.pem` and sit directly in the embedded
/// folder. `https_pems!()` returns an error if either one is missing or empty.
#[macro_export]
macro_rules! https_pems {
    ($path: literal) => {
        // Like `spa_server_root!`, bring `rust_embed` into scope where the derive expands, so
        // the calling crate needs no direct rust-embed dependency. The import lives in a
        // private module so it cannot collide with the `use` that `spa_server_root!` adds
        // to the caller's module.
        mod __spa_rs_https_pems {
            use $crate::rust_embed;

            #[derive(rust_embed::RustEmbed)]
            #[folder = $path]
            pub struct HttpsPems;
        }
        use __spa_rs_https_pems::HttpsPems;
    };

    () => {{
        // A closure gives `?` / `bail!` a `Result` to return from inside an expression.
        let https_config = || -> anyhow::Result<$crate::HttpsConfig> {
            let mut cert = Vec::new();
            let mut key = Vec::new();
            for file in HttpsPems::iter() {
                if let Some(f) = HttpsPems::get(&file) {
                    match &*file {
                        "cert.pem" => cert = f.data.to_vec(),
                        "key.pem" => key = f.data.to_vec(),
                        _ => {}
                    }
                }
            }

            if cert.is_empty() || key.is_empty() {
                anyhow::bail!("invalid ssl cert or key embed file");
            }

            Ok($crate::HttpsConfig {
                certificate: cert,
                private_key: key,
            })
        };
        https_config()
    }};
}
459
/// Specifies the SPA dist directory to embed into the binary at compile time.
///
/// The macro has two forms:
/// - `spa_server_root!("web/dist")` (item position) declares a `StaticFiles` type that
///   embeds that directory. It also imports `rust_embed` into the calling module so the
///   name resolves where the derive expands. Invoke it once per module.
/// - `spa_server_root!()` evaluates to that `StaticFiles` value, ready to pass to
///   [`SpaServer::run`].
///
/// ## Example
/// ```ignore
/// spa_rs::spa_server_root!("web/dist");
///
/// srv.run(spa_rs::spa_server_root!()).await?;
/// ```
#[macro_export]
macro_rules! spa_server_root {
    ($root: literal) => {
        use $crate::rust_embed;

        #[derive(rust_embed::RustEmbed)]
        #[folder = $root]
        struct StaticFiles;

        impl $crate::SpaStatic for StaticFiles {}
    };
    () => {
        StaticFiles
    };
}
477
478/// Used to release static file into temp dir in runtime.
479///
480pub trait SpaStatic: rust_embed::RustEmbed {
481    fn release(&self, release_path: PathBuf) -> Result<PathBuf> {
482        let target_dir = release_path;
483        if !target_dir.exists() {
484            create_dir_all(&target_dir)?;
485        }
486
487        for file in Self::iter() {
488            match Self::get(&file) {
489                Some(f) => {
490                    if let Some(p) = Path::new(file.as_ref()).parent() {
491                        let parent_dir = target_dir.join(p);
492                        create_dir_all(parent_dir)?;
493                    }
494
495                    let path = target_dir.join(file.as_ref());
496                    debug!("release static file: {}", path.display());
497                    if let Err(e) = fs::write(path, f.data) {
498                        error!("static file {} write error: {:?}", file, e);
499                    }
500                }
501                None => warn!("static file {} not found", file),
502            }
503        }
504
505        Ok(target_dir)
506    }
507}
508
509impl SpaStatic for ApiOnly {}
510impl rust_embed::RustEmbed for ApiOnly {
511    fn get(_file_path: &str) -> Option<rust_embed::EmbeddedFile> {
512        unreachable!()
513    }
514
515    fn iter() -> rust_embed::Filenames {
516        unreachable!()
517    }
518}
519
520struct ApiOnly;