// utoipa_swagger_ui/axum.rs

1#![cfg(feature = "axum")]
2
3use std::sync::Arc;
4
5use axum::{
6    body::Body,
7    extract::Path,
8    http::{HeaderMap, Request, Response, StatusCode},
9    middleware::{self, Next},
10    response::IntoResponse,
11    routing, Extension, Json, Router,
12};
13use base64::{prelude::BASE64_STANDARD, Engine};
14
15use crate::{ApiDoc, BasicAuth, Config, SwaggerUi, Url};
16
17impl<S> From<SwaggerUi> for Router<S>
18where
19    S: Clone + Send + Sync + 'static,
20{
21    fn from(swagger_ui: SwaggerUi) -> Self {
22        let urls_capacity = swagger_ui.urls.len();
23        let external_urls_capacity = swagger_ui.external_urls.len();
24
25        let (router, urls) = swagger_ui.urls.into_iter().fold(
26            (
27                Router::<S>::new(),
28                Vec::<Url>::with_capacity(urls_capacity + external_urls_capacity),
29            ),
30            |router_and_urls, (url, openapi)| {
31                add_api_doc_to_urls(router_and_urls, (url, Arc::new(ApiDoc::Utoipa(openapi))))
32            },
33        );
34        let (router, urls) = swagger_ui.external_urls.into_iter().fold(
35            (router, urls),
36            |router_and_urls, (url, openapi)| {
37                add_api_doc_to_urls(router_and_urls, (url, Arc::new(ApiDoc::Value(openapi))))
38            },
39        );
40
41        let config = if let Some(config) = swagger_ui.config {
42            if config.url.is_some() || !config.urls.is_empty() {
43                config
44            } else {
45                config.configure_defaults(urls)
46            }
47        } else {
48            Config::new(urls)
49        };
50
51        let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config.clone())));
52        let path: &str = swagger_ui.path.as_ref();
53
54        let mut router = if path == "/" {
55            router
56                .route(path, handler.clone())
57                .route(&format!("{}{{*rest}}", path), handler)
58        } else {
59            let path = if path.ends_with('/') {
60                &path[..path.len() - 1]
61            } else {
62                path
63            };
64            debug_assert!(!path.is_empty());
65
66            let slash_path = format!("{}/", path);
67            router
68                .route(
69                    path,
70                    routing::get(|| async move { axum::response::Redirect::to(&slash_path) }),
71                )
72                .route(&format!("{}/", path), handler.clone())
73                .route(&format!("{}/{{*rest}}", path), handler)
74        };
75
76        if let Some(BasicAuth { username, password }) = config.basic_auth {
77            let username = Arc::new(username);
78            let password = Arc::new(password);
79            let basic_auth_middleware =
80                move |headers: HeaderMap, req: Request<Body>, next: Next| {
81                    let username = username.clone();
82                    let password = password.clone();
83                    async move {
84                        if let Some(header) = headers.get("Authorization") {
85                            if let Ok(header_str) = header.to_str() {
86                                let base64_encoded_credentials =
87                                    BASE64_STANDARD.encode(format!("{}:{}", &username, &password));
88                                if header_str == format!("Basic {}", base64_encoded_credentials) {
89                                    return Ok::<Response<Body>, StatusCode>(next.run(req).await);
90                                }
91                            }
92                        }
93                        Ok::<Response<Body>, StatusCode>(
94                            (
95                                StatusCode::UNAUTHORIZED,
96                                [("WWW-Authenticate", "Basic realm=\":\"")],
97                            )
98                                .into_response(),
99                        )
100                    }
101                };
102            router = router.layer(middleware::from_fn(basic_auth_middleware));
103        }
104
105        router
106    }
107}
108
109fn add_api_doc_to_urls<S>(
110    router_and_urls: (Router<S>, Vec<Url<'static>>),
111    url: (Url<'static>, Arc<ApiDoc>),
112) -> (Router<S>, Vec<Url<'static>>)
113where
114    S: Clone + Send + Sync + 'static,
115{
116    let (router, mut urls) = router_and_urls;
117    let (url, openapi) = url;
118    (
119        router.route(
120            url.url.as_ref(),
121            routing::get(move || async { Json(openapi) }),
122        ),
123        {
124            urls.push(url);
125            urls
126        },
127    )
128}
129
130async fn serve_swagger_ui(
131    path: Option<Path<String>>,
132    Extension(state): Extension<Arc<Config<'static>>>,
133) -> impl IntoResponse {
134    let tail = match path.as_ref() {
135        Some(tail) => tail,
136        None => "",
137    };
138
139    match super::serve(tail, state) {
140        Ok(file) => file
141            .map(|file| {
142                (
143                    StatusCode::OK,
144                    [("Content-Type", file.content_type)],
145                    file.bytes,
146                )
147                    .into_response()
148            })
149            .unwrap_or_else(|| StatusCode::NOT_FOUND.into_response()),
150        Err(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response(),
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::AUTHORIZATION;
    use http::HeaderValue;
    use tower::util::ServiceExt;

    #[tokio::test]
    async fn mount_onto_root() {
        let app = Router::<()>::from(SwaggerUi::new("/"));
        let response = app.clone().oneshot(get("/")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let response = app.clone().oneshot(get("/swagger-ui.css")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn mount_onto_path_ends_with_slash() {
        let app = Router::<()>::from(SwaggerUi::new("/swagger-ui/"));
        let response = app.clone().oneshot(get("/swagger-ui")).await.unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let response = app.clone().oneshot(get("/swagger-ui/")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let request = get("/swagger-ui/swagger-ui.css");
        let response = app.clone().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn mount_onto_path_not_end_with_slash() {
        let app = Router::<()>::from(SwaggerUi::new("/swagger-ui"));
        let response = app.clone().oneshot(get("/swagger-ui")).await.unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let response = app.clone().oneshot(get("/swagger-ui/")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let request = get("/swagger-ui/swagger-ui.css");
        let response = app.clone().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth() {
        let app = Router::<()>::from(swagger_ui_with_basic_auth());
        let response = app.clone().oneshot(get("/swagger-ui")).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let encoded_credentials = BASE64_STANDARD.encode("admin:password");
        let authorization = format!("Basic {}", encoded_credentials);
        let request = authorized_get("/swagger-ui", &authorization);
        let response = app.clone().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let request = authorized_get("/swagger-ui/", &authorization);
        let response = app.clone().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let request = authorized_get("/swagger-ui/swagger-ui.css", &authorization);
        let response = app.clone().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// A present but wrong `Authorization` header (bad password, other scheme, garbage
    /// credentials) must still be rejected with a `WWW-Authenticate` challenge.
    #[tokio::test]
    async fn basic_auth_rejects_invalid_credentials() {
        let app = Router::<()>::from(swagger_ui_with_basic_auth());
        let wrong_password = format!("Basic {}", BASE64_STANDARD.encode("admin:wrong"));
        for authorization in [wrong_password.as_str(), "Bearer token", "Basic not-base64"] {
            let request = authorized_get("/swagger-ui/", authorization);
            let response = app.clone().oneshot(request).await.unwrap();
            assert_eq!(
                response.status(),
                StatusCode::UNAUTHORIZED,
                "authorization: {}",
                authorization
            );
            assert!(response.headers().contains_key("WWW-Authenticate"));
        }
    }

    /// Swagger UI mounted at `/swagger-ui`, protected by `admin` / `password`.
    fn swagger_ui_with_basic_auth() -> SwaggerUi {
        SwaggerUi::new("/swagger-ui").config(Config::default().basic_auth(BasicAuth {
            username: "admin".to_string(),
            password: "password".to_string(),
        }))
    }

    fn get(url: &str) -> Request<Body> {
        Request::builder().uri(url).body(Body::empty()).unwrap()
    }

    fn authorized_get(url: &str, authorization: &str) -> Request<Body> {
        Request::builder()
            .uri(url)
            .header(AUTHORIZATION, HeaderValue::from_str(authorization).unwrap())
            .body(Body::empty())
            .unwrap()
    }
}