spring_web/
middleware.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
use crate::config::CorsMiddleware;
use crate::config::{
    EnableMiddleware, LimitPayloadMiddleware, Middlewares, StaticAssetsMiddleware,
    TimeoutRequestMiddleware, TraceLoggerMiddleware,
};
use anyhow::Context;
use axum::Router;
use spring::error::Result;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use tower_http::trace::DefaultMakeSpan;
use tower_http::trace::DefaultOnRequest;
use tower_http::trace::DefaultOnResponse;
use tower_http::{
    catch_panic::CatchPanicLayer,
    compression::CompressionLayer,
    cors::CorsLayer,
    limit::RequestBodyLimitLayer,
    services::{ServeDir, ServeFile},
    timeout::TimeoutLayer,
    trace::TraceLayer,
};
use trace::DefaultOnEos;

pub use tower_http::*;

/// Wraps `router` with every middleware that is switched on in `middleware`.
///
/// Sections are handled in a fixed order: catch-panic, compression, trace
/// logger, request timeout, payload limit, CORS and finally static assets.
/// A section that is absent (`None`) or whose `enable` flag is `false` is
/// skipped.
///
/// # Panics
///
/// Panics when the configured body limit cannot be parsed as a byte size,
/// when the CORS layer cannot be built from the configuration, or when
/// `apply_static_dir` panics.
pub(crate) fn apply_middleware(mut router: Router, middleware: Middlewares) -> Router {
    if matches!(middleware.catch_panic, Some(EnableMiddleware { enable: true })) {
        router = router.layer(CatchPanicLayer::new());
    }
    if matches!(middleware.compression, Some(EnableMiddleware { enable: true })) {
        router = router.layer(CompressionLayer::new());
    }
    if let Some(TraceLoggerMiddleware { enable: true, level }) = middleware.logger {
        // The same level is used for the span and for every lifecycle event.
        let level = level.into();
        router = router.layer(
            TraceLayer::new_for_http()
                .make_span_with(DefaultMakeSpan::default().level(level))
                .on_request(DefaultOnRequest::default().level(level))
                .on_response(DefaultOnResponse::default().level(level))
                .on_eos(DefaultOnEos::default().level(level)),
        );
    }
    if let Some(TimeoutRequestMiddleware { enable: true, timeout }) = middleware.timeout_request {
        // `timeout` is passed to `Duration::from_millis`, i.e. it is in milliseconds.
        router = router.layer(TimeoutLayer::new(Duration::from_millis(timeout)));
    }
    if let Some(LimitPayloadMiddleware { enable: true, body_limit }) = middleware.limit_payload {
        let limit = byte_unit::Byte::from_str(&body_limit)
            .unwrap_or_else(|_| panic!("parse limit payload str failed: {}", &body_limit));

        // Saturate rather than truncate: on targets where `usize` is narrower
        // than 64 bits a plain `as usize` would wrap large limits (4 GiB
        // becomes 0), which would reject every request body.
        let max_bytes = limit.as_u64().min(usize::MAX as u64) as usize;
        router = router.layer(RequestBodyLimitLayer::new(max_bytes));
    }
    if let Some(cors) = middleware.cors.filter(|cors| cors.enable) {
        let cors = build_cors_middleware(&cors).expect("cors middleware build failed");
        router = router.layer(cors);
    }
    // NOTE(review): the static-assets service is nested after all `layer`
    // calls above; confirm against axum's `Router::layer` semantics whether
    // it is meant to be covered by those layers.
    if let Some(static_assets) = middleware.static_assets.filter(|assets| assets.enable) {
        router = apply_static_dir(router, static_assets);
    }
    router
}

/// Mounts a static-file service on `router` at `static_assets.uri`.
///
/// Files are served from `static_assets.path`; requests that match no file
/// are answered with the `static_assets.fallback` file. When
/// `static_assets.precompressed` is set, precompressed gzip variants are
/// enabled on the directory service.
///
/// # Panics
///
/// Panics when `static_assets.must_exist` is set and either the directory or
/// the fallback file does not exist.
fn apply_static_dir(router: Router, static_assets: StaticAssetsMiddleware) -> Router {
    if static_assets.must_exist {
        // Short-circuits like the `||` form: the fallback is only probed when
        // the directory itself was found.
        let everything_present = PathBuf::from(&static_assets.path).exists()
            && PathBuf::from(&static_assets.fallback).exists();
        if !everything_present {
            panic!(
                "one of the static path are not found, Folder `{}` fallback: `{}`",
                static_assets.path, static_assets.fallback
            );
        }
    }

    let fallback_file = ServeFile::new(static_assets.fallback);
    let mut assets = ServeDir::new(static_assets.path).not_found_service(fallback_file);

    if static_assets.precompressed {
        tracing::info!("[Middleware] Enable precompressed static assets");
        assets = assets.precompressed_gzip();
    }

    router.nest_service(&static_assets.uri, assets)
}

fn build_cors_middleware(cors: &CorsMiddleware) -> Result<CorsLayer> {
    let mut layer = CorsLayer::new();

    if let Some(allow_origins) = &cors.allow_origins {
        if allow_origins.iter().any(|item| item == "*") {
            layer = layer.allow_origin(cors::Any);
        } else {
            let mut origins = Vec::with_capacity(allow_origins.len());
            for origin in allow_origins {
                let origin = origin
                    .parse()
                    .with_context(|| format!("cors origin parse failed:{}", origin))?;
                origins.push(origin);
            }
            layer = layer.allow_origin(origins);
        }
    }

    if let Some(allow_headers) = &cors.allow_headers {
        if allow_headers.iter().any(|item| item == "*") {
            layer = layer.allow_headers(cors::Any);
        } else {
            let mut headers = Vec::with_capacity(allow_headers.len());
            for header in allow_headers {
                let header = header
                    .parse()
                    .with_context(|| format!("http header parse failed:{}", header))?;
                headers.push(header);
            }
            layer = layer.allow_headers(headers);
        }
    }

    if let Some(allow_methods) = &cors.allow_methods {
        if allow_methods.iter().any(|item| item == "*") {
            layer = layer.allow_methods(cors::Any);
        } else {
            let mut methods = Vec::with_capacity(allow_methods.len());
            for method in allow_methods {
                let method = method
                    .parse()
                    .with_context(|| format!("http method parse failed:{}", method))?;
                methods.push(method);
            }
            layer = layer.allow_methods(methods);
        }
    }

    if let Some(max_age) = cors.max_age {
        layer = layer.max_age(Duration::from_secs(max_age));
    }

    Ok(layer)
}