Skip to main content

rustauth_plugins/open_api/
mod.rs

1//! OpenAPI schema and reference plugin.
2
3use http::{header, Method, StatusCode};
4use rustauth_core::api::{
5    api_error, build_openapi_schema, core_auth_async_endpoints, create_auth_endpoint, ApiErrorCode,
6    ApiResponse, AsyncAuthEndpoint, AuthEndpointOptions, OpenApiOperation,
7};
8use rustauth_core::context::AuthContext;
9use rustauth_core::error::RustAuthError;
10use rustauth_core::plugin::AuthPlugin;
11use serde::{Deserialize, Serialize};
12use serde_json::json;
13
/// Identifier under which this plugin is registered (passed to `AuthPlugin::new`).
pub const UPSTREAM_PLUGIN_ID: &str = "open-api";
15
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct OpenApiOptions {
18    pub path: String,
19    pub disable_default_reference: bool,
20    pub theme: String,
21    pub nonce: Option<String>,
22}
23
24impl Default for OpenApiOptions {
25    fn default() -> Self {
26        Self {
27            path: "/reference".to_owned(),
28            disable_default_reference: false,
29            theme: "default".to_owned(),
30            nonce: None,
31        }
32    }
33}
34
35impl OpenApiOptions {
36    #[must_use]
37    pub fn builder() -> OpenApiOptionsBuilder {
38        OpenApiOptionsBuilder::default()
39    }
40
41    #[must_use]
42    pub fn path(mut self, path: impl Into<String>) -> Self {
43        self.path = normalize_path(path.into());
44        self
45    }
46
47    #[must_use]
48    pub fn disable_default_reference(mut self, disabled: bool) -> Self {
49        self.disable_default_reference = disabled;
50        self
51    }
52
53    #[must_use]
54    pub fn theme(mut self, theme: impl Into<String>) -> Self {
55        self.theme = theme.into();
56        self
57    }
58
59    #[must_use]
60    pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
61        self.nonce = Some(nonce.into());
62        self
63    }
64}
65
/// Incremental builder for `OpenApiOptions`.
/// Any field left unset takes its value from `OpenApiOptions::default()` in `build`.
#[derive(Debug, Clone, Default)]
pub struct OpenApiOptionsBuilder {
    // Stored already normalized by the `path` setter.
    path: Option<String>,
    disable_default_reference: Option<bool>,
    theme: Option<String>,
    // Outer `Option` records whether the `nonce` setter was called at all.
    nonce: Option<Option<String>>,
}
73
74impl OpenApiOptionsBuilder {
75    #[must_use]
76    pub fn path(mut self, path: impl Into<String>) -> Self {
77        self.path = Some(normalize_path(path.into()));
78        self
79    }
80
81    #[must_use]
82    pub fn disable_default_reference(mut self, disabled: bool) -> Self {
83        self.disable_default_reference = Some(disabled);
84        self
85    }
86
87    #[must_use]
88    pub fn theme(mut self, theme: impl Into<String>) -> Self {
89        self.theme = Some(theme.into());
90        self
91    }
92
93    #[must_use]
94    pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
95        self.nonce = Some(Some(nonce.into()));
96        self
97    }
98
99    #[must_use]
100    pub fn build(self) -> OpenApiOptions {
101        let defaults = OpenApiOptions::default();
102        OpenApiOptions {
103            path: self.path.unwrap_or(defaults.path),
104            disable_default_reference: self
105                .disable_default_reference
106                .unwrap_or(defaults.disable_default_reference),
107            theme: self.theme.unwrap_or(defaults.theme),
108            nonce: self.nonce.unwrap_or(defaults.nonce),
109        }
110    }
111}
112
113#[must_use]
114pub fn open_api(options: OpenApiOptions) -> AuthPlugin {
115    AuthPlugin::new(UPSTREAM_PLUGIN_ID)
116        .with_version(crate::VERSION)
117        .with_options(serde_json::to_value(&options).unwrap_or(serde_json::Value::Null))
118        .with_endpoint(generate_schema_endpoint())
119        .with_endpoint(reference_endpoint(options))
120}
121
122fn generate_schema_endpoint() -> AsyncAuthEndpoint {
123    create_auth_endpoint(
124        "/open-api/generate-schema",
125        Method::GET,
126        AuthEndpointOptions::new()
127            .operation_id("generateOpenAPISchema")
128            .openapi(
129                OpenApiOperation::new("generateOpenAPISchema")
130                    .description("Generate the OpenAPI schema for this RustAuth instance")
131                    .response(
132                        "200",
133                        json!({
134                            "description": "OpenAPI schema",
135                            "content": {
136                                "application/json": {
137                                    "schema": {
138                                        "type": "object"
139                                    }
140                                }
141                            }
142                        }),
143                    ),
144            ),
145        |context, _request| async move {
146            json_response(
147                StatusCode::OK,
148                serde_json::to_vec(&schema_for_context(&context))
149                    .map_err(|error| RustAuthError::Api(error.to_string()))?,
150            )
151        },
152    )
153}
154
155fn reference_endpoint(options: OpenApiOptions) -> AsyncAuthEndpoint {
156    let path = options.path.clone();
157    create_auth_endpoint(
158        path,
159        Method::GET,
160        AuthEndpointOptions::new()
161            .operation_id("openApiReference")
162            .hide_from_openapi()
163            .openapi(
164                OpenApiOperation::new("openApiReference")
165                    .summary("OpenAPI reference")
166                    .description("Serve the interactive OpenAPI reference"),
167            ),
168        move |context, _request| {
169            let options = options.clone();
170            async move {
171                if options.disable_default_reference {
172                    return api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound);
173                }
174                html_response(get_html(
175                    &schema_for_context(&context),
176                    &options.theme,
177                    options.nonce.as_deref(),
178                ))
179            }
180        },
181    )
182}
183
184fn schema_for_context(context: &AuthContext) -> serde_json::Value {
185    let mut endpoints = core_auth_async_endpoints();
186    for plugin in &context.plugins {
187        endpoints.extend(plugin.endpoints.iter().cloned());
188    }
189    build_openapi_schema(context, &endpoints)
190}
191
192fn get_html(api_reference: &serde_json::Value, theme: &str, nonce: Option<&str>) -> String {
193    let nonce_attr = nonce
194        .map(|nonce| format!(" nonce=\"{}\"", escape_html_attr(nonce)))
195        .unwrap_or_default();
196    let api_reference = escape_script_json(&api_reference.to_string());
197    let theme = escape_js_string(theme);
198    format!(
199        r#"<!doctype html>
200<html>
201  <head>
202    <title>RustAuth API Reference</title>
203    <meta charset="utf-8" />
204    <meta name="viewport" content="width=device-width, initial-scale=1" />
205  </head>
206  <body>
207    <script id="api-reference" type="application/json">{api_reference}</script>
208    <script{nonce_attr}>
209      var configuration = {{
210        theme: "{theme}",
211        metaData: {{
212          title: "RustAuth API",
213          description: "API Reference for your RustAuth instance"
214        }}
215      }}
216      document.getElementById("api-reference").dataset.configuration =
217        JSON.stringify(configuration)
218    </script>
219    <script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"{nonce_attr}></script>
220  </body>
221</html>"#,
222        api_reference = api_reference,
223        theme = theme,
224        nonce_attr = nonce_attr,
225    )
226}
227
228fn json_response(status: StatusCode, body: Vec<u8>) -> Result<ApiResponse, RustAuthError> {
229    http::Response::builder()
230        .status(status)
231        .header(header::CONTENT_TYPE, "application/json")
232        .body(body)
233        .map_err(|error| RustAuthError::Api(error.to_string()))
234}
235
236fn html_response(body: String) -> Result<ApiResponse, RustAuthError> {
237    http::Response::builder()
238        .status(StatusCode::OK)
239        .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
240        .body(body.into_bytes())
241        .map_err(|error| RustAuthError::Api(error.to_string()))
242}
243
/// Ensures `path` starts with `/`, prepending one when missing (so `""` becomes `"/"`).
fn normalize_path(path: String) -> String {
    let mut normalized = path;
    if !normalized.starts_with('/') {
        normalized.insert(0, '/');
    }
    normalized
}
251
/// Escapes `&`, `"`, `<` and `>` as HTML entities.
///
/// This makes `value` safe inside a double-quoted HTML attribute.
fn escape_html_attr(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            match ch {
                '&' => out.push_str("&amp;"),
                '"' => out.push_str("&quot;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                other => out.push(other),
            }
            out
        })
}
259
/// Escapes `value` for use inside a double-quoted JavaScript string literal within an
/// HTML `<script>` element.
///
/// - Backslash and `"` are backslash-escaped.
/// - `\n`, `\r` and `\t` use their short escapes. A raw line terminator would end the
///   literal early and cause a syntax error.
/// - Every other ASCII control character becomes `\uXXXX`.
/// - `&`, `<`, `>`, U+2028 and U+2029 become `\uXXXX`, so the text cannot close the
///   `<script>` element or break older JS parsers.
fn escape_js_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for character in value.chars() {
        match character {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            '&' => escaped.push_str("\\u0026"),
            '<' => escaped.push_str("\\u003c"),
            '>' => escaped.push_str("\\u003e"),
            '\u{2028}' => escaped.push_str("\\u2028"),
            '\u{2029}' => escaped.push_str("\\u2029"),
            // Remaining C0 controls and DEL: never emit them raw into script source.
            control if control.is_ascii_control() => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(control)));
            }
            character => escaped.push(character),
        }
    }
    escaped
}
276
/// Makes serialized JSON safe to embed inside a `<script>` element.
///
/// - `&`, `<` and `>` become `\u0026`, `\u003c` and `\u003e`, so sequences like
///   `</script>` or `<!--` cannot appear.
/// - U+2028 and U+2029 become `\u2028` and `\u2029`.
///
/// When `value` is JSON, these characters can only occur inside string values, so the
/// result is still equivalent JSON.
fn escape_script_json(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        let replacement = match ch {
            '&' => "\\u0026",
            '<' => "\\u003c",
            '>' => "\\u003e",
            '\u{2028}' => "\\u2028",
            '\u{2029}' => "\\u2029",
            other => {
                out.push(other);
                continue;
            }
        };
        out.push_str(replacement);
    }
    out
}