1use crate::config::CorsMiddleware;
2use crate::config::{
3 EnableMiddleware, LimitPayloadMiddleware, Middlewares, StaticAssetsMiddleware,
4 TimeoutRequestMiddleware, TraceLoggerMiddleware,
5};
6use crate::Router;
7use anyhow::Context;
8use axum::http::StatusCode;
9use spring::error::Result;
10use std::path::PathBuf;
11use std::str::FromStr;
12use std::time::Duration;
13use tower_http::trace::DefaultMakeSpan;
14use tower_http::trace::DefaultOnRequest;
15use tower_http::trace::DefaultOnResponse;
16use tower_http::{
17 catch_panic::CatchPanicLayer,
18 compression::CompressionLayer,
19 cors::CorsLayer,
20 limit::RequestBodyLimitLayer,
21 services::{ServeDir, ServeFile},
22 timeout::TimeoutLayer,
23 trace::TraceLayer,
24};
25use trace::DefaultOnEos;
26
27pub use tower_http::*;
28
29pub(crate) fn apply_middleware(mut router: Router, middleware: Middlewares) -> Router {
30 router = router.layer(axum::middleware::from_fn(crate::problem_details::capture_request_uri_middleware));
32
33 if Some(EnableMiddleware { enable: true }) == middleware.catch_panic {
34 router = router.layer(CatchPanicLayer::new());
35 }
36 if Some(EnableMiddleware { enable: true }) == middleware.compression {
37 router = router.layer(CompressionLayer::new());
38 }
39 if let Some(TraceLoggerMiddleware { enable, level }) = middleware.logger {
40 if enable {
41 let level = level.into();
42 router = router.layer(
43 TraceLayer::new_for_http()
44 .make_span_with(DefaultMakeSpan::default().level(level))
45 .on_request(DefaultOnRequest::default().level(level))
46 .on_response(DefaultOnResponse::default().level(level))
47 .on_eos(DefaultOnEos::default().level(level)),
48 );
49 }
50 }
51 if let Some(TimeoutRequestMiddleware { enable, timeout }) = middleware.timeout_request {
52 if enable {
53 router = router.layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_millis(timeout)));
54 }
55 }
56 if let Some(LimitPayloadMiddleware { enable, body_limit }) = middleware.limit_payload {
57 if enable {
58 let limit = byte_unit::Byte::from_str(&body_limit)
59 .unwrap_or_else(|_| panic!("parse limit payload str failed: {}", &body_limit));
60
61 let limit_payload = RequestBodyLimitLayer::new(limit.as_u64() as usize);
62 router = router.layer(limit_payload);
63 }
64 }
65 if let Some(cors) = middleware.cors {
66 if cors.enable {
67 let cors = build_cors_middleware(&cors).expect("cors middleware build failed");
68 router = router.layer(cors);
69 }
70 }
71 if let Some(static_assets) = middleware.static_assets {
72 if static_assets.enable {
73 router = apply_static_dir(router, static_assets);
74 }
75 }
76 router
77}
78
79fn apply_static_dir(router: Router, static_assets: StaticAssetsMiddleware) -> Router {
80 if static_assets.must_exist
81 && (!PathBuf::from(&static_assets.path).exists()
82 || !PathBuf::from(&static_assets.fallback).exists())
83 {
84 panic!(
85 "one of the static path are not found, Folder `{}` fallback: `{}`",
86 static_assets.path, static_assets.fallback
87 );
88 }
89
90 let fallback = ServeFile::new(format!("{}/{}", static_assets.path, static_assets.fallback));
91 let serve_dir = ServeDir::new(static_assets.path).not_found_service(fallback);
92
93 let service = if static_assets.precompressed {
94 tracing::info!("[Middleware] Enable precompressed static assets");
95 serve_dir.precompressed_gzip()
96 } else {
97 serve_dir
98 };
99
100 if static_assets.uri == "/" {
101 router.fallback_service(service)
102 } else {
103 router.nest_service(&static_assets.uri, service)
104 }
105}
106
107fn build_cors_middleware(cors: &CorsMiddleware) -> Result<CorsLayer> {
108 let mut layer = CorsLayer::new();
109
110 if let Some(allow_origins) = &cors.allow_origins {
111 if allow_origins.iter().any(|item| item == "*") {
112 layer = layer.allow_origin(cors::Any);
113 } else {
114 let mut origins = Vec::with_capacity(allow_origins.len());
115 for origin in allow_origins {
116 let origin = origin
117 .parse()
118 .with_context(|| format!("cors origin parse failed:{origin}"))?;
119 origins.push(origin);
120 }
121 layer = layer.allow_origin(origins);
122 }
123 }
124
125 if let Some(allow_headers) = &cors.allow_headers {
126 if allow_headers.iter().any(|item| item == "*") {
127 layer = layer.allow_headers(cors::Any);
128 } else {
129 let mut headers = Vec::with_capacity(allow_headers.len());
130 for header in allow_headers {
131 let header = header
132 .parse()
133 .with_context(|| format!("http header parse failed:{header}"))?;
134 headers.push(header);
135 }
136 layer = layer.allow_headers(headers);
137 }
138 }
139
140 if let Some(allow_methods) = &cors.allow_methods {
141 if allow_methods.iter().any(|item| item == "*") {
142 layer = layer.allow_methods(cors::Any);
143 } else {
144 let mut methods = Vec::with_capacity(allow_methods.len());
145 for method in allow_methods {
146 let method = method
147 .parse()
148 .with_context(|| format!("http method parse failed:{method}"))?;
149 methods.push(method);
150 }
151 layer = layer.allow_methods(methods);
152 }
153 }
154
155 if let Some(max_age) = cors.max_age {
156 layer = layer.max_age(Duration::from_secs(max_age));
157 }
158
159 Ok(layer)
160}