salvo_extra/
force_https.rs1use std::borrow::Cow;
46
47use salvo_core::handler::Skipper;
48use salvo_core::http::header;
49use salvo_core::http::uri::{Scheme, Uri};
50use salvo_core::http::{Request, ResBody, Response};
51use salvo_core::writing::Redirect;
52use salvo_core::{async_trait, Depot, FlowCtrl, Handler};
53
54#[derive(Default)]
56pub struct ForceHttps {
57 https_port: Option<u16>,
58 skipper: Option<Box<dyn Skipper>>,
59}
60impl ForceHttps {
61 pub fn new() -> Self {
63 Default::default()
64 }
65
66 pub fn https_port(self, port: u16) -> Self {
68 Self {
69 https_port: Some(port),
70 ..self
71 }
72 }
73
74 pub fn skipper(self, skipper: impl Skipper) -> Self {
76 Self {
77 skipper: Some(Box::new(skipper)),
78 ..self
79 }
80 }
81}
82
83#[async_trait]
84impl Handler for ForceHttps {
85 async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
86 if req.uri().scheme() == Some(&Scheme::HTTPS)
87 || self
88 .skipper
89 .as_ref()
90 .map(|skipper| skipper.skipped(req, depot))
91 .unwrap_or(false)
92 {
93 return;
94 }
95 if let Some(host) = req.header::<String>(header::HOST) {
96 let host = redirect_host(&host, self.https_port);
97 let uri_parts = std::mem::take(req.uri_mut()).into_parts();
98 let mut builder = Uri::builder().scheme(Scheme::HTTPS).authority(&*host);
99 if let Some(path_and_query) = uri_parts.path_and_query {
100 builder = builder.path_and_query(path_and_query);
101 }
102 if let Ok(uri) = builder.build() {
103 res.body(ResBody::None);
104 res.render(Redirect::permanent(uri));
105 ctrl.skip_rest();
106 }
107 }
108 }
109}
110
/// Builds the authority (`host[:port]`) for the HTTPS redirect target.
///
/// When `https_port` is `Some`, any port already present in `host` is replaced
/// by it. When it is `None`, `host` is returned unchanged as a borrow (no
/// allocation). Bracketed IPv6 literals such as `[::1]:8080` are handled, so
/// the colons inside the address are not mistaken for the port separator.
fn redirect_host(host: &str, https_port: Option<u16>) -> Cow<'_, str> {
    match https_port {
        Some(port) => {
            let bare = strip_port(host);
            Cow::Owned(format!("{bare}:{port}"))
        }
        None => Cow::Borrowed(host),
    }
}

/// Returns `host` without its trailing `:port`, if it has one.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        // IPv6 literal: the address runs up to and including the closing bracket.
        host.find(']').map_or(host, |end| &host[..=end])
    } else {
        host.split_once(':').map_or(host, |(name, _)| name)
    }
}
118
#[cfg(test)]
mod tests {
    use salvo_core::http::header::{HOST, LOCATION};
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[test]
    fn test_redirect_host() {
        // A bare host gains the configured port.
        assert_eq!(redirect_host("example.com", Some(1234)), "example.com:1234");
        // An existing port is replaced by the configured one.
        assert_eq!(redirect_host("example.com:5678", Some(1234)), "example.com:1234");
        // An existing port equal to the configured one is kept as-is.
        assert_eq!(redirect_host("example.com:1234", Some(1234)), "example.com:1234");
        // Without a configured port the host is returned unchanged.
        assert_eq!(redirect_host("example.com:1234", None), "example.com:1234");
        assert_eq!(redirect_host("example.com", None), "example.com");
    }

    #[handler]
    async fn hello() -> &'static str {
        "Hello World"
    }

    #[tokio::test]
    async fn test_redirect_handler() {
        let router = Router::with_hoop(ForceHttps::new().https_port(1234)).goal(hello);
        let response = TestClient::get("http://127.0.0.1:5800/")
            .add_header(HOST, "127.0.0.1:5800", true)
            .send(router)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::PERMANENT_REDIRECT));
        assert_eq!(
            response.headers().get(LOCATION),
            Some(&"https://127.0.0.1:1234/".parse().unwrap())
        );
    }
}