1use anyhow::{anyhow, Context, Result};
48use axum::{
49 body::Bytes,
50 body::HttpBody,
51 extract::{Host, Request},
52 http::HeaderValue,
53 response::Response,
54 routing::{any, get_service, Route},
55};
56#[cfg(feature = "openssl")]
57use axum_server::tls_openssl::OpenSSLConfig;
58#[cfg(feature = "rustls")]
59use axum_server::tls_rustls::RustlsConfig;
60use http::{
61 header::{self},
62 StatusCode,
63};
64#[cfg(feature = "reverse-proxy")]
65use http::{Method, Uri};
66use log::{debug, error, warn};
67use std::{
68 collections::HashMap,
69 convert::Infallible,
70 env::current_exe,
71 fs::{self, create_dir_all},
72 net::SocketAddr,
73 path::{Path, PathBuf},
74};
75use tower::{Layer, Service, ServiceExt as TowerServiceExt};
76use tower_http::{
77 services::{ServeDir, ServeFile},
78 set_header::SetResponseHeaderLayer,
79};
80
// Re-exports everything from the `rust_embed` crate under `<this crate>::rust_embed`,
// so the macros below can refer to it through a `spa_rs::rust_embed` path.
pub mod rust_embed {
    pub use rust_embed::*;
}
84
// Re-export the whole `axum` API from this crate's root.
pub use axum::*;
// Submodules of this crate (their contents are defined in separate files).
pub mod auth;
pub mod session;
// Explicit re-export of axum's `debug_handler` attribute macro.
pub use axum::debug_handler;
// Re-export everything from the `axum_help` crate as well.
pub use axum_help::*;
90
/// Builder-style configuration for a single-page-application server.
///
/// Values are collected through the chained setter methods and consumed by
/// the `run*` methods, which build the final router and start listening.
///
/// `T` is the type of the optional shared data that is attached to the router
/// as an `Extension` layer when the server is started.
///
/// Note: the derived `Default` leaves `port` at `0` and `release_path` empty,
/// which differs from the values chosen by `SpaServer::new`.
#[derive(Default)]
pub struct SpaServer<T = ()>
where
    T: Clone + Send + Sync + 'static,
{
    /// `(mount path, directory)` pairs; each is mounted as a directory service at startup.
    static_path: Vec<(String, PathBuf)>,
    /// TCP port the server binds on (on `0.0.0.0`).
    port: u16,
    /// Top-level router; rebuilt at startup to dispatch on the request host.
    main_router: Router,
    /// Router that receives nested routes, static paths and the SPA fallback.
    api_router: Router,
    /// Optional shared data, added as an `Extension` layer at startup.
    data: Option<T>,
    /// Optional address handed to the `forwarded_to_dev` fallback (reverse-proxy mode).
    forward: Option<String>,
    /// Directory the embedded static files are written to before serving.
    release_path: PathBuf,
    /// Deferred layer applications, run against the main router at startup in insertion order.
    extra_layer: Vec<Box<dyn FnOnce(Router) -> Router>>,
    /// Routers selected by the request's host value instead of `api_router`.
    host_routers: HashMap<String, Router>,
}
114
/// Fallback handler used when a reverse-proxy address is configured
/// (presumably meant to forward unmatched requests to a dev server — the
/// forwarding logic is not present here).
///
/// The address arrives through the `Extension<String>` layer that is added
/// next to this fallback. The body is currently a `compile_error!`, so
/// enabling the `reverse-proxy` feature fails the build with that message.
#[cfg(feature = "reverse-proxy")]
async fn forwarded_to_dev(
    Extension(forward_addr): Extension<String>,
    uri: Uri,
    method: Method,
) -> HttpResult<Response> {
    compile_error!("Can not use now, wait for reqwest upgrade hyper to 1.0");
    todo!()
}
151
/// Stand-in for the reverse-proxy fallback when the `reverse-proxy` feature is off.
///
/// It only exists so the code that references `forwarded_to_dev` still compiles.
/// Without the feature there is no public setter for the forward address, so
/// this handler is not expected to be called; it panics if it ever is.
#[cfg(not(feature = "reverse-proxy"))]
async fn forwarded_to_dev() {
    unreachable!("reverse-proxy not enabled, should never call forwarded_to_dev")
}
156
157impl<T> SpaServer<T>
158where
159 T: Clone + Send + Sync + 'static,
160{
161 pub fn new() -> Result<Self> {
163 Ok(Self {
164 static_path: Vec::new(),
165 port: 8080,
166 main_router: Router::new(),
167 forward: None,
168 release_path: current_exe()?
169 .parent()
170 .ok_or_else(|| anyhow!("no parent in current_exe"))?
171 .join(format!(".{}_static_files", env!("CARGO_PKG_NAME"))),
172 extra_layer: Vec::new(),
173 host_routers: HashMap::new(),
174 api_router: Router::new(),
175 data: None,
176 })
177 }
178
179 pub fn data(mut self, data: T) -> Self
183 where
184 T: Clone + Send + Sync + 'static,
185 {
186 self.data = Some(data);
187 self
188 }
189
190 pub fn layer<L, NewResBody>(mut self, layer: L) -> Self
194 where
195 L: Layer<Route> + Clone + Send + 'static,
196 L::Service: Service<Request, Response = Response<NewResBody>, Error = Infallible>
197 + Clone
198 + Send
199 + 'static,
200 <L::Service as Service<Request>>::Future: Send + 'static,
201 NewResBody: HttpBody<Data = Bytes> + Send + 'static,
202 NewResBody::Error: Into<BoxError>,
203 {
204 self.extra_layer.push(Box::new(move |app| app.layer(layer)));
205 self
206 }
207
208 #[cfg(feature = "reverse-proxy")]
212 #[cfg_attr(docsrs, doc(cfg(feature = "reverse-proxy")))]
213 pub fn reverse_proxy(mut self, addr: impl Into<String>) -> Self {
214 self.forward = Some(addr.into());
215 self
216 }
217
218 pub fn release_path(mut self, rp: impl Into<PathBuf>) -> Self {
222 self.release_path = rp.into();
223 self
224 }
225
    /// Releases the embedded files of `root` and serves the application
    /// without TLS.
    ///
    /// Thin wrapper around `run_raw` with `Some(root)` and no HTTPS config;
    /// any error from `run_raw` is returned unchanged.
    pub async fn run<Root>(self, root: Root) -> Result<()>
    where
        Root: SpaStatic,
    {
        self.run_raw(Some(root), None).await
    }
233
    /// Releases the embedded files of `root` and serves the application over
    /// TLS using the certificate and key in `config`.
    ///
    /// Only available when the `openssl` or `rustls` feature is enabled.
    /// Thin wrapper around `run_raw`.
    #[cfg(any(feature = "openssl", feature = "rustls"))]
    pub async fn run_tls<Root>(self, root: Root, config: HttpsConfig) -> Result<()>
    where
        Root: SpaStatic,
    {
        self.run_raw(Some(root), Some(config)).await
    }
242
    /// Serves only the configured routes, without embedded SPA files and
    /// without TLS.
    ///
    /// `ApiOnly` is used purely as the type parameter; no root value is
    /// passed, so nothing is released to disk.
    pub async fn run_api(self) -> Result<()> {
        self.run_raw::<ApiOnly>(None, None).await
    }
247
    /// Serves only the configured routes over TLS, without embedded SPA files.
    ///
    /// Only available when the `openssl` or `rustls` feature is enabled.
    /// `ApiOnly` is used purely as the type parameter; no root value is passed.
    #[cfg(any(feature = "openssl", feature = "rustls"))]
    pub async fn run_api_tls(self, config: HttpsConfig) -> Result<()> {
        self.run_raw::<ApiOnly>(None, Some(config)).await
    }
253
254 async fn run_raw<Root>(mut self, root: Option<Root>, config: Option<HttpsConfig>) -> Result<()>
256 where
257 Root: SpaStatic,
258 {
259 if let Some(root) = root {
260 let embeded_dir = root.release(self.release_path)?;
261 let index_file = embeded_dir.clone().join("index.html");
262
263 self.api_router = if let Some(addr) = self.forward {
264 self.api_router
265 .fallback(forwarded_to_dev)
266 .layer(Extension(addr))
267 } else {
268 self.api_router.fallback_service(
269 get_service(ServeDir::new(&embeded_dir).fallback(ServeFile::new(index_file)))
270 .layer(Self::add_cache_control())
271 .handle_error(|e: anyhow::Error| async move {
272 (
273 StatusCode::INTERNAL_SERVER_ERROR,
274 format!(
275 "Unhandled internal server error {:?} when serve embeded path {}",
276 e,
277 embeded_dir.display()
278 ),
279 )
280 }),
281 )
282 };
283 }
284
285 for sf in self.static_path {
286 self.api_router = self.api_router.nest_service(
287 &sf.0,
288 get_service(ServeDir::new(&sf.1))
289 .layer(Self::add_cache_control())
290 .handle_error(|e: anyhow::Error| async move {
291 (
292 StatusCode::INTERNAL_SERVER_ERROR,
293 format!(
294 "Unhandled internal server error {:?} when serve static path {}",
295 e,
296 sf.1.display()
297 ),
298 )
299 }),
300 )
301 }
302
303 let main_handler = |Host(hostname): Host, request: Request| async move {
304 if let Some(router) = self.host_routers.remove(&hostname) {
305 router.oneshot(request).await
306 } else {
307 self.api_router.oneshot(request).await
308 }
309 };
310 self.main_router = Router::new()
311 .route("/", any(main_handler.clone()))
312 .route("/*path", any(main_handler));
313
314 if let Some(data) = self.data {
315 self.main_router = self.main_router.layer(Extension(data));
316 }
317
318 for layer in self.extra_layer {
319 self.main_router = layer(self.main_router)
320 }
321
322 let addr = format!("0.0.0.0:{}", self.port).parse()?;
323 if let Some(_config) = config {
324 #[cfg(all(feature = "openssl", feature = "rustls"))]
325 compile_error!("Feature openssl and Feature rustls can not be enabled together");
326
327 #[cfg(any(feature = "openssl", feature = "rustls"))]
328 {
329 #[cfg(feature = "rustls")]
330 {
331 axum_server::bind_rustls(
332 addr,
333 RustlsConfig::from_pem(_config.certificate, _config.private_key).await?,
334 )
335 }
336 #[cfg(feature = "openssl")]
337 {
338 let temp_dir = std::env::temp_dir().join(env!("CARGO_PKG_NAME"));
339 std::fs::create_dir_all(&temp_dir)?;
340 let cert_file = temp_dir.join("cert.pem");
341 let key_file = temp_dir.join("key.pem");
342 std::fs::write(&cert_file, &_config.certificate)?;
343 std::fs::write(&key_file, &_config.private_key)?;
344 axum_server::bind_openssl(
345 addr,
346 OpenSSLConfig::from_pem_file(cert_file, key_file)
347 .context("openssl load pem file error")?,
348 )
349 }
350 }
351 .serve(
352 self.main_router
353 .into_make_service_with_connect_info::<SocketAddr>(),
354 )
355 .await?;
356 } else {
357 axum_server::bind(addr)
358 .serve(
359 self.main_router
360 .into_make_service_with_connect_info::<SocketAddr>(),
361 )
362 .await
363 .context("serve server error")?;
364 }
365
366 Ok(())
367 }
368
369 pub fn route(mut self, path: impl AsRef<str>, router: Router) -> Self {
372 self.api_router = self.api_router.nest(path.as_ref(), router);
373 self
374 }
375
376 pub fn port(mut self, port: u16) -> Self {
379 self.port = port;
380 self
381 }
382
383 pub fn static_path(mut self, path: impl Into<String>, dir: impl Into<PathBuf>) -> Self {
387 self.static_path.push((path.into(), dir.into()));
388 self
389 }
390
391 pub fn host_router(mut self, host: impl Into<String>, router: Router) -> Self {
394 self.host_routers.insert(host.into(), router);
395 self
396 }
397
398 fn add_cache_control() -> SetResponseHeaderLayer<HeaderValue> {
399 SetResponseHeaderLayer::if_not_present(
400 header::CACHE_CONTROL,
401 HeaderValue::from_static("max-age=300"),
402 )
403 }
404}
405
/// TLS material used by the `*_tls` run methods.
///
/// Both fields hold raw bytes that are loaded as PEM data (directly for
/// rustls, via temporary `.pem` files for openssl).
pub struct HttpsConfig {
    /// Certificate bytes.
    pub certificate: Vec<u8>,
    /// Private key bytes.
    pub private_key: Vec<u8>,
}
410
/// Embeds TLS PEM files and builds an `HttpsConfig` from them.
///
/// Two forms:
/// - `https_pems!("<folder>")` declares a `HttpsPems` struct that embeds the
///   given folder through `RustEmbed`.
/// - `https_pems!()` evaluates to `anyhow::Result<spa_rs::HttpsConfig>`, built
///   from the embedded files named `cert.pem` and `key.pem`. It fails with
///   "invalid ssl cert or key embed file" when either one is missing or empty.
///
/// The second form relies on `HttpsPems` from the first form being in scope.
#[macro_export]
macro_rules! https_pems {
    ($path: literal) => {
        #[derive(spa_rs::rust_embed::RustEmbed)]
        #[folder = $path]
        struct HttpsPems;
    };

    () => {{
        let https_config = || -> anyhow::Result<spa_rs::HttpsConfig> {
            let mut cert = Vec::new();
            let mut key = Vec::new();
            for file in HttpsPems::iter() {
                if let Some(f) = HttpsPems::get(&file) {
                    // Assigns the file contents to the variable whose name,
                    // suffixed with `.pem`, equals the embedded file name.
                    macro_rules! setup {
                        ($t: expr) => {
                            if file == format!("{}.pem", stringify!($t)) {
                                $t = f.data.to_vec();
                            }
                        };
                    }
                    setup!(cert);
                    setup!(key);
                }
            }

            // Both files are required; an empty buffer means one was not found.
            if cert.is_empty() || key.is_empty() {
                anyhow::bail!("invalid ssl cert or key embed file");
            }

            Ok(spa_rs::HttpsConfig {
                certificate: cert,
                private_key: key,
            })
        };
        https_config()
    }};
}
459
/// Declares, or refers to, the embedded root of the single-page application.
///
/// Two forms:
/// - `spa_server_root!("<folder>")` declares a `StaticFiles` struct that embeds
///   the given folder through `RustEmbed` and implements `spa_rs::SpaStatic`.
/// - `spa_server_root!()` expands to the `StaticFiles` value, suitable for
///   passing to the server's run methods.
#[macro_export]
macro_rules! spa_server_root {
    ($root: literal) => {
        // Brings `rust_embed` into scope for the derive below.
        use spa_rs::rust_embed;

        #[derive(rust_embed::RustEmbed)]
        #[folder = $root]
        struct StaticFiles;

        impl spa_rs::SpaStatic for StaticFiles {}
    };
    () => {
        StaticFiles
    };
}
477
478pub trait SpaStatic: rust_embed::RustEmbed {
481 fn release(&self, release_path: PathBuf) -> Result<PathBuf> {
482 let target_dir = release_path;
483 if !target_dir.exists() {
484 create_dir_all(&target_dir)?;
485 }
486
487 for file in Self::iter() {
488 match Self::get(&file) {
489 Some(f) => {
490 if let Some(p) = Path::new(file.as_ref()).parent() {
491 let parent_dir = target_dir.join(p);
492 create_dir_all(parent_dir)?;
493 }
494
495 let path = target_dir.join(file.as_ref());
496 debug!("release static file: {}", path.display());
497 if let Err(e) = fs::write(path, f.data) {
498 error!("static file {} write error: {:?}", file, e);
499 }
500 }
501 None => warn!("static file {} not found", file),
502 }
503 }
504
505 Ok(target_dir)
506 }
507}
508
509impl SpaStatic for ApiOnly {}
510impl rust_embed::RustEmbed for ApiOnly {
511 fn get(_file_path: &str) -> Option<rust_embed::EmbeddedFile> {
512 unreachable!()
513 }
514
515 fn iter() -> rust_embed::Filenames {
516 unreachable!()
517 }
518}
519
520struct ApiOnly;