salvo_extra/
force_https.rs

//! Middleware that forces requests to be redirected to HTTPS.
//!
//! The force-https middleware redirects every plain-HTTP request to the same
//! URL using the HTTPS scheme.
//!
//! If the middleware is attached to a `Router`, the redirect only happens when
//! that route matches; requests for pages that do not exist are not redirected.
//!
//! More commonly, however, every request should be redirected, even when no
//! route matches and a 404 error would be returned. In that case, add the
//! middleware to the `Service` instead: middleware added to the `Service` is
//! always executed, regardless of whether the request matches a route.
//!
//! # Example
//!
//! ```no_run
//! use salvo_core::prelude::*;
//! use salvo_core::conn::rustls::{Keycert, RustlsConfig};
//! use salvo_extra::force_https::ForceHttps;
//!
//! #[handler]
//! async fn hello() -> &'static str {
//!     "hello"
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     let router = Router::new().get(hello);
//!     let service = Service::new(router).hoop(ForceHttps::new().https_port(5443));
//!
//!     let config = RustlsConfig::new(
//!         Keycert::new()
//!             .cert(include_bytes!("../../core/certs/cert.pem").as_ref())
//!             .key(include_bytes!("../../core/certs/key.pem").as_ref()),
//!     );
//!     let acceptor = TcpListener::new("0.0.0.0:5443")
//!         .rustls(config)
//!         .join(TcpListener::new("0.0.0.0:5800"))
//!         .bind()
//!         .await;
//!     Server::new(acceptor).serve(service).await;
//! }
//! ```
45use std::borrow::Cow;
46
47use salvo_core::handler::Skipper;
48use salvo_core::http::header;
49use salvo_core::http::uri::{Scheme, Uri};
50use salvo_core::http::{Request, ResBody, Response};
51use salvo_core::writing::Redirect;
52use salvo_core::{async_trait, Depot, FlowCtrl, Handler};
53
54/// Middleware for force redirect to http uri.
55#[derive(Default)]
56pub struct ForceHttps {
57    https_port: Option<u16>,
58    skipper: Option<Box<dyn Skipper>>,
59}
60impl ForceHttps {
61    /// Create new `ForceHttps` middleware.
62    pub fn new() -> Self {
63        Default::default()
64    }
65
66    /// Specify https port.
67    pub fn https_port(self, port: u16) -> Self {
68        Self {
69            https_port: Some(port),
70            ..self
71        }
72    }
73
74    /// Uses a closure to determine if a request should be redirect.
75    pub fn skipper(self, skipper: impl Skipper) -> Self {
76        Self {
77            skipper: Some(Box::new(skipper)),
78            ..self
79        }
80    }
81}
82
83#[async_trait]
84impl Handler for ForceHttps {
85    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
86        if req.uri().scheme() == Some(&Scheme::HTTPS)
87            || self
88                .skipper
89                .as_ref()
90                .map(|skipper| skipper.skipped(req, depot))
91                .unwrap_or(false)
92        {
93            return;
94        }
95        if let Some(host) = req.header::<String>(header::HOST) {
96            let host = redirect_host(&host, self.https_port);
97            let uri_parts = std::mem::take(req.uri_mut()).into_parts();
98            let mut builder = Uri::builder().scheme(Scheme::HTTPS).authority(&*host);
99            if let Some(path_and_query) = uri_parts.path_and_query {
100                builder = builder.path_and_query(path_and_query);
101            }
102            if let Ok(uri) = builder.build() {
103                res.body(ResBody::None);
104                res.render(Redirect::permanent(uri));
105                ctrl.skip_rest();
106            }
107        }
108    }
109}
110
/// Builds the `host[:port]` authority used as the HTTPS redirect target.
///
/// * `host` - the request's host, optionally carrying a port
///   (`example.com`, `example.com:8080`, `[::1]:8080`).
/// * `https_port` - when `Some`, any port already present in `host` is replaced
///   by this one; when `None`, `host` is returned unchanged (borrowed).
///
/// Bracketed IPv6 literals are handled so that the colons inside the address
/// are not mistaken for the port separator.
fn redirect_host(host: &str, https_port: Option<u16>) -> Cow<'_, str> {
    match https_port {
        Some(port) => {
            let name = strip_port(host);
            Cow::Owned(format!("{name}:{port}"))
        }
        None => Cow::Borrowed(host),
    }
}

/// Returns `host` without a trailing `:port`, keeping IPv6 brackets intact.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        // IPv6 literal such as `[::1]:8080`: the address ends at the closing bracket.
        match host.find(']') {
            Some(end) => &host[..=end],
            None => host,
        }
    } else {
        host.split_once(':').map_or(host, |(name, _)| name)
    }
}
118
#[cfg(test)]
mod tests {
    use salvo_core::http::header::{HOST, LOCATION};
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[test]
    fn test_redirect_host() {
        // An explicit HTTPS port replaces whatever port the host carried.
        assert_eq!(redirect_host("example.com", Some(1234)), "example.com:1234");
        assert_eq!(redirect_host("example.com:5678", Some(1234)), "example.com:1234");
        assert_eq!(redirect_host("localhost:80", Some(443)), "localhost:443");
        // Without an HTTPS port the host is passed through unchanged (and not copied).
        assert_eq!(redirect_host("example.com:1234", None), "example.com:1234");
        assert_eq!(redirect_host("example.com", None), "example.com");
        assert!(matches!(redirect_host("example.com", None), Cow::Borrowed("example.com")));
    }

    #[handler]
    async fn hello() -> &'static str {
        "Hello World"
    }

    #[tokio::test]
    async fn test_redirect_handler() {
        let router = Router::with_hoop(ForceHttps::new().https_port(1234)).goal(hello);
        let response = TestClient::get("http://127.0.0.1:5800/")
            .add_header(HOST, "127.0.0.1:5800", true)
            .send(router)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::PERMANENT_REDIRECT));
        assert_eq!(
            response.headers().get(LOCATION),
            Some(&"https://127.0.0.1:1234/".parse().unwrap())
        );
    }
}