1use crate::config::CorsMiddleware;
2use crate::config::{
3 EnableMiddleware, LimitPayloadMiddleware, Middlewares, StaticAssetsMiddleware,
4 TimeoutRequestMiddleware, TraceLoggerMiddleware,
5};
6use anyhow::Context;
7use axum::Router;
8use spring::error::Result;
9use std::path::PathBuf;
10use std::str::FromStr;
11use std::time::Duration;
12use tower_http::trace::DefaultMakeSpan;
13use tower_http::trace::DefaultOnRequest;
14use tower_http::trace::DefaultOnResponse;
15use tower_http::{
16 catch_panic::CatchPanicLayer,
17 compression::CompressionLayer,
18 cors::CorsLayer,
19 limit::RequestBodyLimitLayer,
20 services::{ServeDir, ServeFile},
21 timeout::TimeoutLayer,
22 trace::TraceLayer,
23};
24use trace::DefaultOnEos;
25
26pub use tower_http::*;
27
28pub(crate) fn apply_middleware(mut router: Router, middleware: Middlewares) -> Router {
29 if Some(EnableMiddleware { enable: true }) == middleware.catch_panic {
30 router = router.layer(CatchPanicLayer::new());
31 }
32 if Some(EnableMiddleware { enable: true }) == middleware.compression {
33 router = router.layer(CompressionLayer::new());
34 }
35 if let Some(TraceLoggerMiddleware { enable, level }) = middleware.logger {
36 if enable {
37 let level = level.into();
38 router = router.layer(
39 TraceLayer::new_for_http()
40 .make_span_with(DefaultMakeSpan::default().level(level))
41 .on_request(DefaultOnRequest::default().level(level))
42 .on_response(DefaultOnResponse::default().level(level))
43 .on_eos(DefaultOnEos::default().level(level)),
44 );
45 }
46 }
47 if let Some(TimeoutRequestMiddleware { enable, timeout }) = middleware.timeout_request {
48 if enable {
49 router = router.layer(TimeoutLayer::new(Duration::from_millis(timeout)));
50 }
51 }
52 if let Some(LimitPayloadMiddleware { enable, body_limit }) = middleware.limit_payload {
53 if enable {
54 let limit = byte_unit::Byte::from_str(&body_limit)
55 .unwrap_or_else(|_| panic!("parse limit payload str failed: {}", &body_limit));
56
57 let limit_payload = RequestBodyLimitLayer::new(limit.as_u64() as usize);
58 router = router.layer(limit_payload);
59 }
60 }
61 if let Some(cors) = middleware.cors {
62 if cors.enable {
63 let cors = build_cors_middleware(&cors).expect("cors middleware build failed");
64 router = router.layer(cors);
65 }
66 }
67 if let Some(static_assets) = middleware.static_assets {
68 if static_assets.enable {
69 router = apply_static_dir(router, static_assets);
70 }
71 }
72 router
73}
74
75fn apply_static_dir(router: Router, static_assets: StaticAssetsMiddleware) -> Router {
76 if static_assets.must_exist
77 && (!PathBuf::from(&static_assets.path).exists()
78 || !PathBuf::from(&static_assets.fallback).exists())
79 {
80 panic!(
81 "one of the static path are not found, Folder `{}` fallback: `{}`",
82 static_assets.path, static_assets.fallback
83 );
84 }
85
86 let fallback = ServeFile::new(format!("{}/{}", static_assets.path, static_assets.fallback));
87 let serve_dir = ServeDir::new(static_assets.path).not_found_service(fallback);
88
89 let service = if static_assets.precompressed {
90 tracing::info!("[Middleware] Enable precompressed static assets");
91 serve_dir.precompressed_gzip()
92 } else {
93 serve_dir
94 };
95
96 if static_assets.uri == "/" {
97 router.fallback_service(service)
98 } else {
99 router.nest_service(&static_assets.uri, service)
100 }
101}
102
103fn build_cors_middleware(cors: &CorsMiddleware) -> Result<CorsLayer> {
104 let mut layer = CorsLayer::new();
105
106 if let Some(allow_origins) = &cors.allow_origins {
107 if allow_origins.iter().any(|item| item == "*") {
108 layer = layer.allow_origin(cors::Any);
109 } else {
110 let mut origins = Vec::with_capacity(allow_origins.len());
111 for origin in allow_origins {
112 let origin = origin
113 .parse()
114 .with_context(|| format!("cors origin parse failed:{}", origin))?;
115 origins.push(origin);
116 }
117 layer = layer.allow_origin(origins);
118 }
119 }
120
121 if let Some(allow_headers) = &cors.allow_headers {
122 if allow_headers.iter().any(|item| item == "*") {
123 layer = layer.allow_headers(cors::Any);
124 } else {
125 let mut headers = Vec::with_capacity(allow_headers.len());
126 for header in allow_headers {
127 let header = header
128 .parse()
129 .with_context(|| format!("http header parse failed:{}", header))?;
130 headers.push(header);
131 }
132 layer = layer.allow_headers(headers);
133 }
134 }
135
136 if let Some(allow_methods) = &cors.allow_methods {
137 if allow_methods.iter().any(|item| item == "*") {
138 layer = layer.allow_methods(cors::Any);
139 } else {
140 let mut methods = Vec::with_capacity(allow_methods.len());
141 for method in allow_methods {
142 let method = method
143 .parse()
144 .with_context(|| format!("http method parse failed:{}", method))?;
145 methods.push(method);
146 }
147 layer = layer.allow_methods(methods);
148 }
149 }
150
151 if let Some(max_age) = cors.max_age {
152 layer = layer.max_age(Duration::from_secs(max_age));
153 }
154
155 Ok(layer)
156}