Skip to main content

rustauth_core/api/
openapi.rs

1use std::collections::BTreeMap;
2
3use http::Method;
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6
7use crate::api::additional_fields::AdditionalField as RuntimeAdditionalField;
8use crate::context::AuthContext;
9use crate::db::{DbField, DbFieldType, DbValue};
10
11use super::endpoint::AsyncAuthEndpoint;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct OpenApiOperation {
15    pub operation_id: Option<String>,
16    pub summary: Option<String>,
17    pub description: Option<String>,
18    pub tags: Vec<String>,
19    pub parameters: Vec<Value>,
20    pub request_body: Option<Value>,
21    pub responses: BTreeMap<String, Value>,
22}
23
24impl OpenApiOperation {
25    pub fn new(operation_id: impl Into<String>) -> Self {
26        Self {
27            operation_id: Some(operation_id.into()),
28            summary: None,
29            description: None,
30            tags: Vec::new(),
31            parameters: Vec::new(),
32            request_body: None,
33            responses: BTreeMap::new(),
34        }
35    }
36
37    #[must_use]
38    pub fn summary(mut self, summary: impl Into<String>) -> Self {
39        self.summary = Some(summary.into());
40        self
41    }
42
43    #[must_use]
44    pub fn description(mut self, description: impl Into<String>) -> Self {
45        self.description = Some(description.into());
46        self
47    }
48
49    #[must_use]
50    pub fn tag(mut self, tag: impl Into<String>) -> Self {
51        self.tags.push(tag.into());
52        self
53    }
54
55    #[must_use]
56    pub fn request_body(mut self, request_body: Value) -> Self {
57        self.request_body = Some(request_body);
58        self
59    }
60
61    #[must_use]
62    pub fn parameter(mut self, parameter: Value) -> Self {
63        self.parameters.push(parameter);
64        self
65    }
66
67    #[must_use]
68    pub fn response(mut self, status: impl Into<String>, response: Value) -> Self {
69        self.responses.insert(status.into(), response);
70        self
71    }
72}
73
74pub(super) fn openapi_operation_for_endpoint(endpoint: &AsyncAuthEndpoint) -> Value {
75    let mut operation = endpoint
76        .options
77        .openapi
78        .clone()
79        .unwrap_or_else(|| OpenApiOperation {
80            operation_id: endpoint.options.operation_id.clone(),
81            summary: None,
82            description: None,
83            tags: Vec::new(),
84            parameters: Vec::new(),
85            request_body: None,
86            responses: BTreeMap::new(),
87        });
88    let operation_id = operation
89        .operation_id
90        .clone()
91        .or_else(|| endpoint.options.operation_id.clone());
92    if operation.summary.is_none() {
93        operation.summary = operation_id.as_deref().map(humanize_operation_id);
94    }
95    if operation.description.is_none() {
96        operation.description = operation
97            .summary
98            .as_ref()
99            .map(|summary| format!("{summary} endpoint"));
100    }
101    add_missing_path_parameters(&mut operation.parameters, &endpoint.path);
102    let request_body = operation.request_body.or_else(|| {
103        endpoint
104            .options
105            .body_schema
106            .as_ref()
107            .map(|schema| {
108                json!({
109                    "required": true,
110                    "content": {
111                        "application/json": {
112                            "schema": schema.openapi_schema(),
113                        },
114                    },
115                })
116            })
117            .or_else(|| {
118                method_uses_request_body(&endpoint.method).then(|| {
119                    json!({
120                        "content": {
121                            "application/json": {
122                                "schema": {
123                                    "type": "object",
124                                    "properties": {},
125                                },
126                            },
127                        },
128                    })
129                })
130            })
131    });
132    let mut responses = default_openapi_responses();
133    for (status, response) in operation.responses {
134        responses.insert(status, response);
135    }
136    if !responses
137        .keys()
138        .any(|status| status.starts_with('2') || status.starts_with('3'))
139    {
140        responses.insert(
141            "200".to_owned(),
142            json_openapi_response(
143                "Success",
144                json!({
145                    "type": "object",
146                    "properties": {},
147                }),
148            ),
149        );
150    }
151    let mut tags = if operation.tags.is_empty() {
152        vec![tag_for_endpoint(endpoint, operation_id.as_deref())]
153    } else {
154        Vec::new()
155    };
156    for tag in operation.tags {
157        if !tags.iter().any(|existing| existing == &tag) {
158            tags.push(tag);
159        }
160    }
161
162    let mut value = serde_json::Map::new();
163    value.insert(
164        "tags".to_owned(),
165        Value::Array(tags.into_iter().map(Value::String).collect()),
166    );
167    if let Some(description) = operation.description {
168        value.insert("description".to_owned(), Value::String(description));
169    }
170    if let Some(summary) = operation.summary {
171        value.insert("summary".to_owned(), Value::String(summary));
172    }
173    if let Some(operation_id) = operation_id {
174        value.insert("operationId".to_owned(), Value::String(operation_id));
175    }
176    value.insert(
177        "security".to_owned(),
178        json!([
179            {
180                "bearerAuth": [],
181            },
182        ]),
183    );
184    value.insert("parameters".to_owned(), Value::Array(operation.parameters));
185    if let Some(request_body) = request_body {
186        value.insert("requestBody".to_owned(), request_body);
187    }
188    value.insert("responses".to_owned(), Value::Object(responses));
189    Value::Object(value)
190}
191
192fn add_missing_path_parameters(parameters: &mut Vec<Value>, path: &str) {
193    for name in path
194        .split('/')
195        .filter_map(|part| part.strip_prefix(':'))
196        .filter(|name| !name.is_empty())
197    {
198        let exists = parameters.iter().any(|parameter| {
199            parameter.get("name").and_then(Value::as_str) == Some(name)
200                && parameter.get("in").and_then(Value::as_str) == Some("path")
201        });
202        if !exists {
203            parameters.push(path_param(name, &format!("Path parameter `{name}`")));
204        }
205    }
206}
207
/// Turns an operation id such as `signInEmail` or `get_session` into a
/// sentence-like summary (`Sign in email`, `Get session`).
///
/// A new word starts after every `_` or `-` and before every uppercase
/// character; letters are ASCII-lowercased and words are joined with single
/// spaces. Only an ASCII first letter is capitalised.
fn humanize_operation_id(operation_id: &str) -> String {
    let mut summary = String::with_capacity(operation_id.len());
    let mut inside_word = false;
    for character in operation_id.chars() {
        if matches!(character, '_' | '-') {
            inside_word = false;
            continue;
        }
        if character.is_uppercase() {
            inside_word = false;
        }
        // Separate words with one space, but never lead with a space.
        if !inside_word && !summary.is_empty() {
            summary.push(' ');
        }
        summary.push(character.to_ascii_lowercase());
        inside_word = true;
    }

    // True only when the string is non-empty and its first character is a
    // single byte, i.e. ASCII.
    if summary.is_char_boundary(1) {
        summary[..1].make_ascii_uppercase();
    }
    summary
}
233
234fn tag_for_endpoint(endpoint: &AsyncAuthEndpoint, operation_id: Option<&str>) -> String {
235    if let Some(tag) = tag_for_operation_id(operation_id.unwrap_or_default()) {
236        return tag.to_owned();
237    }
238    let first_segment = endpoint
239        .path
240        .split('/')
241        .find(|segment| !segment.is_empty())
242        .unwrap_or_default();
243    tag_for_path_segment(first_segment)
244        .unwrap_or("Default")
245        .to_owned()
246}
247
/// Infers a tag from well-known fragments of an operation id.
///
/// Rules are checked in order and the first match wins, so an id such as
/// `mcpGetJWT` is tagged `MCP`, not `JWT`. Matching is case-sensitive.
/// Returns `None` when no rule applies.
fn tag_for_operation_id(operation_id: &str) -> Option<&'static str> {
    let has = |fragment: &str| operation_id.contains(fragment);

    if operation_id.starts_with("mcp") || operation_id.starts_with("getMcp") {
        Some("MCP")
    } else if has("JWT") || has("JSONWeb") {
        // A suffix check for "JWT" would be redundant: `contains` covers it.
        Some("JWT")
    } else if has("OAuth2") {
        Some("Generic OAuth")
    } else if has("Siwe") {
        Some("SIWE")
    } else if has("PhoneNumber") {
        Some("Phone Number")
    } else if has("TwoFactor") || has("BackupCode") || has("Otp") {
        Some("Two Factor")
    } else if operation_id.starts_with("organization") || has("Organization") {
        Some("Organization")
    } else {
        None
    }
}
273
/// Maps the first segment of an endpoint path to its tag.
///
/// The lookup is exact and case-sensitive; unknown segments yield `None`.
fn tag_for_path_segment(segment: &str) -> Option<&'static str> {
    /// Known first path segments and the tag each one belongs to.
    const SEGMENT_TAGS: &[(&str, &str)] = &[
        ("mcp", "MCP"),
        ("admin", "Admin"),
        ("anonymous", "Anonymous"),
        ("delete-anonymous-user", "Anonymous"),
        ("device", "Device Authorization"),
        ("device-authorization", "Device Authorization"),
        ("email-otp", "Email OTP"),
        ("oauth2", "Generic OAuth"),
        ("jwt", "JWT"),
        ("jwks", "JWT"),
        ("token", "JWT"),
        ("magic-link", "Magic Link"),
        ("multi-session", "Multi Session"),
        ("oauth-proxy", "OAuth Proxy"),
        ("one-tap", "One Tap"),
        ("one-time-token", "One Time Token"),
        ("open-api", "Open API"),
        ("organization", "Organization"),
        ("phone-number", "Phone Number"),
        ("siwe", "SIWE"),
        ("two-factor", "Two Factor"),
        ("username", "Username"),
    ];

    SEGMENT_TAGS
        .iter()
        .find(|(known, _)| *known == segment)
        .map(|&(_, tag)| tag)
}
297
298pub fn build_openapi_schema(context: &AuthContext, async_endpoints: &[AsyncAuthEndpoint]) -> Value {
299    let mut paths = serde_json::Map::new();
300    for endpoint in async_endpoints {
301        if endpoint.options.server_only || endpoint.options.hide_from_openapi {
302            continue;
303        }
304        let path = paths
305            .entry(to_openapi_path(&endpoint.path))
306            .or_insert_with(|| Value::Object(serde_json::Map::new()));
307        let Value::Object(methods) = path else {
308            continue;
309        };
310        methods.insert(
311            endpoint.method.as_str().to_ascii_lowercase(),
312            openapi_operation_for_endpoint(endpoint),
313        );
314    }
315    json!({
316        "openapi": "3.1.1",
317        "info": {
318            "title": "RustAuth",
319            "description": "API Reference for your RustAuth instance",
320            "version": crate::VERSION,
321        },
322        "components": {
323            "schemas": openapi_model_schemas(context),
324            "securitySchemes": {
325                "apiKeyCookie": {
326                    "type": "apiKey",
327                    "in": "cookie",
328                    "name": "apiKeyCookie",
329                    "description": "API Key authentication via cookie",
330                },
331                "bearerAuth": {
332                    "type": "http",
333                    "scheme": "bearer",
334                    "description": "Bearer token authentication",
335                },
336            },
337        },
338        "security": [
339            {
340                "apiKeyCookie": [],
341                "bearerAuth": [],
342            },
343        ],
344        "servers": [
345            {
346                "url": context.base_url,
347            },
348        ],
349        "tags": [
350            {
351                "name": "Default",
352                "description": "Default endpoints that are included with RustAuth by default. These endpoints are not part of any plugin.",
353            },
354        ],
355        "paths": paths,
356    })
357}
358
359fn method_uses_request_body(method: &Method) -> bool {
360    matches!(*method, Method::POST | Method::PATCH | Method::PUT)
361}
362
363pub(super) fn to_openapi_path(path: &str) -> String {
364    path.split('/')
365        .map(|part| {
366            part.strip_prefix(':')
367                .map(|name| format!("{{{name}}}"))
368                .unwrap_or_else(|| part.to_owned())
369        })
370        .collect::<Vec<_>>()
371        .join("/")
372}
373
374fn default_openapi_responses() -> serde_json::Map<String, Value> {
375    let mut responses = serde_json::Map::new();
376    responses.insert(
377        "400".to_owned(),
378        openapi_error_response(
379            "Bad Request. Usually due to missing parameters, or invalid parameters.",
380            true,
381        ),
382    );
383    responses.insert(
384        "401".to_owned(),
385        openapi_error_response(
386            "Unauthorized. Due to missing or invalid authentication.",
387            true,
388        ),
389    );
390    responses.insert(
391        "403".to_owned(),
392        openapi_error_response(
393            "Forbidden. You do not have permission to access this resource or to perform this action.",
394            false,
395        ),
396    );
397    responses.insert(
398        "404".to_owned(),
399        openapi_error_response("Not Found. The requested resource was not found.", false),
400    );
401    responses.insert(
402        "429".to_owned(),
403        openapi_error_response(
404            "Too Many Requests. You have exceeded the rate limit. Try again later.",
405            false,
406        ),
407    );
408    responses.insert(
409        "500".to_owned(),
410        openapi_error_response(
411            "Internal Server Error. This is a problem with the server that you cannot fix.",
412            false,
413        ),
414    );
415    responses
416}
417
418fn openapi_error_response(description: &str, require_message: bool) -> Value {
419    let mut required = vec!["code"];
420    if require_message {
421        required.push("message");
422    }
423    let mut schema = serde_json::Map::new();
424    schema.insert("type".to_owned(), Value::String("object".to_owned()));
425    schema.insert(
426        "properties".to_owned(),
427        json!({
428            "code": {
429                "type": "string",
430            },
431            "message": {
432                "type": "string",
433            },
434            "originalMessage": {
435                "type": "string",
436            },
437        }),
438    );
439    schema.insert("required".to_owned(), json!(required));
440    json!({
441        "content": {
442            "application/json": {
443                "schema": Value::Object(schema),
444            },
445        },
446        "description": description,
447    })
448}
449
450pub fn json_openapi_response(description: &str, schema: Value) -> Value {
451    json!({
452        "description": description,
453        "content": {
454            "application/json": {
455                "schema": schema,
456            },
457        },
458    })
459}
460
461pub fn empty_openapi_response(description: &str) -> Value {
462    json!({
463        "description": description,
464    })
465}
466
467pub fn redirect_openapi_response(description: &str) -> Value {
468    json!({
469        "description": description,
470        "headers": {
471            "Location": {
472                "description": "Redirect target",
473                "schema": {
474                    "type": "string",
475                    "format": "uri",
476                },
477            },
478        },
479    })
480}
481
482pub fn query_param(name: &str, description: &str) -> Value {
483    json!({
484        "name": name,
485        "in": "query",
486        "required": false,
487        "description": description,
488        "schema": {
489            "type": "string",
490        },
491    })
492}
493
494pub fn path_param(name: &str, description: &str) -> Value {
495    json!({
496        "name": name,
497        "in": "path",
498        "required": true,
499        "description": description,
500        "schema": {
501            "type": "string",
502        },
503    })
504}
505
506pub(super) fn openapi_model_schemas(context: &AuthContext) -> Value {
507    let mut schemas = serde_json::Map::new();
508    for (logical_table, table) in context.db_schema.tables() {
509        let mut properties = serde_json::Map::new();
510        let mut required = Vec::new();
511        for (logical_field, field) in &table.fields {
512            let property_name = openapi_property_name(logical_field);
513            if field.required {
514                required.push(Value::String(property_name.clone()));
515            }
516            properties.insert(
517                property_name,
518                openapi_field_schema(context, logical_table, logical_field, field),
519            );
520        }
521        match logical_table {
522            "user" => append_runtime_additional_fields(
523                context,
524                logical_table,
525                &mut properties,
526                &mut required,
527                &context.options.user.additional_fields,
528            ),
529            "session" => append_runtime_additional_fields(
530                context,
531                logical_table,
532                &mut properties,
533                &mut required,
534                &context.options.session.additional_fields,
535            ),
536            _ => {}
537        }
538
539        schemas.insert(
540            openapi_schema_name(logical_table),
541            json!({
542                "type": "object",
543                "properties": properties,
544                "required": required,
545                "additionalProperties": true,
546            }),
547        );
548    }
549    Value::Object(schemas)
550}
551
552fn append_runtime_additional_fields<F>(
553    context: &AuthContext,
554    logical_table: &str,
555    properties: &mut serde_json::Map<String, Value>,
556    required: &mut Vec<Value>,
557    fields: &std::collections::BTreeMap<String, F>,
558) where
559    F: RuntimeAdditionalField,
560{
561    for (logical_field, field) in fields {
562        let property_name = openapi_property_name(logical_field);
563        if properties.contains_key(&property_name) {
564            continue;
565        }
566        let db_field = DbField {
567            name: field
568                .db_name()
569                .map(str::to_owned)
570                .unwrap_or_else(|| logical_field.clone()),
571            field_type: field.field_type().clone(),
572            required: field.required(),
573            unique: false,
574            index: false,
575            returned: field.returned(),
576            input: field.input(),
577            foreign_key: None,
578            generated_id: None,
579        };
580        if db_field.required {
581            required.push(Value::String(property_name.clone()));
582        }
583        properties.insert(
584            property_name,
585            openapi_field_schema(context, logical_table, logical_field, &db_field),
586        );
587    }
588}
589
590fn openapi_field_schema(
591    context: &AuthContext,
592    logical_table: &str,
593    logical_field: &str,
594    field: &DbField,
595) -> Value {
596    let mut schema = serde_json::Map::new();
597    let type_name = openapi_field_type(&field.field_type);
598    if field.required {
599        schema.insert("type".to_owned(), Value::String(type_name.to_owned()));
600    } else {
601        schema.insert("type".to_owned(), json!([type_name, "null"]));
602    }
603    match field.field_type {
604        DbFieldType::String => {
605            if logical_field == "email" {
606                schema.insert("format".to_owned(), Value::String("email".to_owned()));
607            } else if logical_field == "image" || logical_field == "logo" {
608                schema.insert("format".to_owned(), Value::String("uri".to_owned()));
609            }
610        }
611        DbFieldType::Timestamp => {
612            schema.insert("format".to_owned(), Value::String("date-time".to_owned()));
613        }
614        DbFieldType::StringArray => {
615            schema.insert("items".to_owned(), json!({ "type": "string" }));
616        }
617        DbFieldType::NumberArray => {
618            schema.insert("items".to_owned(), json!({ "type": "number" }));
619        }
620        DbFieldType::Number | DbFieldType::Boolean | DbFieldType::Json => {}
621    }
622    if !field.input {
623        schema.insert("readOnly".to_owned(), Value::Bool(true));
624    }
625    if let Some(default_value) = openapi_field_default(context, logical_table, logical_field) {
626        schema.insert("default".to_owned(), default_value);
627    }
628    Value::Object(schema)
629}
630
631fn openapi_field_type(field_type: &DbFieldType) -> &'static str {
632    match field_type {
633        DbFieldType::String | DbFieldType::Timestamp => "string",
634        DbFieldType::Number => "number",
635        DbFieldType::Boolean => "boolean",
636        DbFieldType::Json => "object",
637        DbFieldType::StringArray | DbFieldType::NumberArray => "array",
638    }
639}
640
641fn openapi_field_default(
642    context: &AuthContext,
643    logical_table: &str,
644    logical_field: &str,
645) -> Option<Value> {
646    let value = match logical_table {
647        "user" => context
648            .options
649            .user
650            .additional_fields
651            .get(logical_field)
652            .and_then(|field| field.default_value.as_ref()),
653        "session" => context
654            .options
655            .session
656            .additional_fields
657            .get(logical_field)
658            .and_then(|field| field.default_value.as_ref()),
659        _ => None,
660    }?;
661    db_value_to_openapi_default(value)
662}
663
664fn db_value_to_openapi_default(value: &DbValue) -> Option<Value> {
665    match value {
666        DbValue::String(value) => Some(Value::String(value.clone())),
667        DbValue::Number(value) => Some(Value::Number((*value).into())),
668        DbValue::Boolean(value) => Some(Value::Bool(*value)),
669        DbValue::Json(value) => Some(value.clone()),
670        DbValue::StringArray(values) => Some(Value::Array(
671            values.iter().cloned().map(Value::String).collect(),
672        )),
673        DbValue::NumberArray(values) => Some(Value::Array(
674            values
675                .iter()
676                .map(|value| Value::Number((*value).into()))
677                .collect(),
678        )),
679        DbValue::Null => Some(Value::Null),
680        DbValue::Timestamp(_) | DbValue::Record(_) | DbValue::RecordArray(_) => None,
681    }
682}
683
684fn openapi_schema_name(logical_table: &str) -> String {
685    match logical_table {
686        "user" => "User".to_owned(),
687        "session" => "Session".to_owned(),
688        "account" => "Account".to_owned(),
689        "verification" => "Verification".to_owned(),
690        "rate_limit" => "RateLimit".to_owned(),
691        "organization" => "Organization".to_owned(),
692        "member" => "Member".to_owned(),
693        "invitation" => "Invitation".to_owned(),
694        "team" => "Team".to_owned(),
695        "team_member" => "TeamMember".to_owned(),
696        "organization_role" => "OrganizationRole".to_owned(),
697        "wallet_address" => "WalletAddress".to_owned(),
698        value => pascal_case(value),
699    }
700}
701
702fn openapi_property_name(logical_field: &str) -> String {
703    snake_to_camel(logical_field)
704}
705
/// Converts a `snake_case` identifier to `camelCase`.
///
/// Underscores are dropped and the character following any run of
/// underscores is uppercased, including after a leading underscore. All other
/// characters are copied unchanged.
fn snake_to_camel(value: &str) -> String {
    let mut segments = value.split('_');
    let mut camel = String::with_capacity(value.len());

    // Text before the first underscore keeps its original casing.
    camel.push_str(segments.next().unwrap_or_default());
    for segment in segments {
        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            camel.extend(initial.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
723
/// Converts an identifier to `PascalCase`.
///
/// `_`, `-` and spaces act as word separators and are dropped; the first
/// character of every word is uppercased and the rest is copied unchanged.
fn pascal_case(value: &str) -> String {
    value
        .split(|character: char| matches!(character, '_' | '-' | ' '))
        .fold(String::with_capacity(value.len()), |mut output, word| {
            let mut letters = word.chars();
            if let Some(initial) = letters.next() {
                output.extend(initial.to_uppercase());
                output.push_str(letters.as_str());
            }
            output
        })
}