1use std::collections::BTreeMap;
2
3use http::Method;
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6
7use crate::api::additional_fields::AdditionalField as RuntimeAdditionalField;
8use crate::context::AuthContext;
9use crate::db::{DbField, DbFieldType, DbValue};
10
11use super::endpoint::AsyncAuthEndpoint;
12
/// Hand-written OpenAPI metadata that an endpoint can attach to its options.
///
/// When the schema is generated, gaps are filled in where possible:
/// - a missing `summary` is derived from the operation id;
/// - a missing `description` is derived from the summary;
/// - a path parameter is appended for every `:name` path segment not already declared;
/// - the default error responses are merged underneath `responses`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OpenApiOperation {
    /// `operationId`; when `None`, the endpoint's own operation id is used instead.
    pub operation_id: Option<String>,
    /// Short `summary`; derived from the operation id when `None`.
    pub summary: Option<String>,
    /// Longer `description`; defaults to `"{summary} endpoint"` when `None`.
    pub description: Option<String>,
    /// Grouping tags. When empty, a tag is inferred from the operation id or path.
    /// Duplicate tags are dropped on output.
    pub tags: Vec<String>,
    /// OpenAPI parameter objects, emitted before any auto-added path parameters.
    pub parameters: Vec<Value>,
    /// OpenAPI request body object; when `None`, one may be derived from the endpoint.
    pub request_body: Option<Value>,
    /// Response objects keyed by status code.
    /// These override the default error responses for the same status.
    pub responses: BTreeMap<String, Value>,
}
23
24impl OpenApiOperation {
25 pub fn new(operation_id: impl Into<String>) -> Self {
26 Self {
27 operation_id: Some(operation_id.into()),
28 summary: None,
29 description: None,
30 tags: Vec::new(),
31 parameters: Vec::new(),
32 request_body: None,
33 responses: BTreeMap::new(),
34 }
35 }
36
37 #[must_use]
38 pub fn summary(mut self, summary: impl Into<String>) -> Self {
39 self.summary = Some(summary.into());
40 self
41 }
42
43 #[must_use]
44 pub fn description(mut self, description: impl Into<String>) -> Self {
45 self.description = Some(description.into());
46 self
47 }
48
49 #[must_use]
50 pub fn tag(mut self, tag: impl Into<String>) -> Self {
51 self.tags.push(tag.into());
52 self
53 }
54
55 #[must_use]
56 pub fn request_body(mut self, request_body: Value) -> Self {
57 self.request_body = Some(request_body);
58 self
59 }
60
61 #[must_use]
62 pub fn parameter(mut self, parameter: Value) -> Self {
63 self.parameters.push(parameter);
64 self
65 }
66
67 #[must_use]
68 pub fn response(mut self, status: impl Into<String>, response: Value) -> Self {
69 self.responses.insert(status.into(), response);
70 self
71 }
72}
73
74pub(super) fn openapi_operation_for_endpoint(endpoint: &AsyncAuthEndpoint) -> Value {
75 let mut operation = endpoint
76 .options
77 .openapi
78 .clone()
79 .unwrap_or_else(|| OpenApiOperation {
80 operation_id: endpoint.options.operation_id.clone(),
81 summary: None,
82 description: None,
83 tags: Vec::new(),
84 parameters: Vec::new(),
85 request_body: None,
86 responses: BTreeMap::new(),
87 });
88 let operation_id = operation
89 .operation_id
90 .clone()
91 .or_else(|| endpoint.options.operation_id.clone());
92 if operation.summary.is_none() {
93 operation.summary = operation_id.as_deref().map(humanize_operation_id);
94 }
95 if operation.description.is_none() {
96 operation.description = operation
97 .summary
98 .as_ref()
99 .map(|summary| format!("{summary} endpoint"));
100 }
101 add_missing_path_parameters(&mut operation.parameters, &endpoint.path);
102 let request_body = operation.request_body.or_else(|| {
103 endpoint
104 .options
105 .body_schema
106 .as_ref()
107 .map(|schema| {
108 json!({
109 "required": true,
110 "content": {
111 "application/json": {
112 "schema": schema.openapi_schema(),
113 },
114 },
115 })
116 })
117 .or_else(|| {
118 method_uses_request_body(&endpoint.method).then(|| {
119 json!({
120 "content": {
121 "application/json": {
122 "schema": {
123 "type": "object",
124 "properties": {},
125 },
126 },
127 },
128 })
129 })
130 })
131 });
132 let mut responses = default_openapi_responses();
133 for (status, response) in operation.responses {
134 responses.insert(status, response);
135 }
136 if !responses
137 .keys()
138 .any(|status| status.starts_with('2') || status.starts_with('3'))
139 {
140 responses.insert(
141 "200".to_owned(),
142 json_openapi_response(
143 "Success",
144 json!({
145 "type": "object",
146 "properties": {},
147 }),
148 ),
149 );
150 }
151 let mut tags = if operation.tags.is_empty() {
152 vec![tag_for_endpoint(endpoint, operation_id.as_deref())]
153 } else {
154 Vec::new()
155 };
156 for tag in operation.tags {
157 if !tags.iter().any(|existing| existing == &tag) {
158 tags.push(tag);
159 }
160 }
161
162 let mut value = serde_json::Map::new();
163 value.insert(
164 "tags".to_owned(),
165 Value::Array(tags.into_iter().map(Value::String).collect()),
166 );
167 if let Some(description) = operation.description {
168 value.insert("description".to_owned(), Value::String(description));
169 }
170 if let Some(summary) = operation.summary {
171 value.insert("summary".to_owned(), Value::String(summary));
172 }
173 if let Some(operation_id) = operation_id {
174 value.insert("operationId".to_owned(), Value::String(operation_id));
175 }
176 value.insert(
177 "security".to_owned(),
178 json!([
179 {
180 "bearerAuth": [],
181 },
182 ]),
183 );
184 value.insert("parameters".to_owned(), Value::Array(operation.parameters));
185 if let Some(request_body) = request_body {
186 value.insert("requestBody".to_owned(), request_body);
187 }
188 value.insert("responses".to_owned(), Value::Object(responses));
189 Value::Object(value)
190}
191
192fn add_missing_path_parameters(parameters: &mut Vec<Value>, path: &str) {
193 for name in path
194 .split('/')
195 .filter_map(|part| part.strip_prefix(':'))
196 .filter(|name| !name.is_empty())
197 {
198 let exists = parameters.iter().any(|parameter| {
199 parameter.get("name").and_then(Value::as_str) == Some(name)
200 && parameter.get("in").and_then(Value::as_str) == Some("path")
201 });
202 if !exists {
203 parameters.push(path_param(name, &format!("Path parameter `{name}`")));
204 }
205 }
206}
207
/// Turns an operation id into a sentence-case summary.
///
/// The id may be camelCase, snake_case or kebab-case. Examples:
/// - `signInEmail` → `Sign in email`
/// - `getJSONWebKeySet` → `Get JSON web key set`
/// - `verifyJWT` → `Verify JWT`
///
/// A new word starts at:
/// - `_` or `-`, which are dropped;
/// - an uppercase letter following a non-uppercase character;
/// - the last capital of an acronym run when a lowercase letter follows it (`JSONWeb` →
///   `JSON`, `Web`).
///
/// Acronyms (two or more characters, none lowercase) keep their case. Every other word is
/// lowercased. Finally the first character is ASCII-uppercased.
fn humanize_operation_id(operation_id: &str) -> String {
    let characters: Vec<char> = operation_id.chars().collect();
    let mut words: Vec<String> = Vec::new();
    let mut current = String::new();
    // Previous character of the word being built; `None` at the start and after a separator,
    // i.e. exactly when `current` is empty.
    let mut previous: Option<char> = None;

    for (index, &character) in characters.iter().enumerate() {
        if character == '_' || character == '-' {
            if !current.is_empty() {
                words.push(std::mem::take(&mut current));
            }
            previous = None;
            continue;
        }
        let starts_new_word = character.is_uppercase()
            && match previous {
                None => false,
                // Inside an acronym run: only break before a Capital+lowercase pair.
                Some(before) if before.is_uppercase() => {
                    matches!(characters.get(index + 1), Some(after) if after.is_lowercase())
                }
                Some(_) => true,
            };
        if starts_new_word {
            words.push(std::mem::take(&mut current));
        }
        current.push(character);
        previous = Some(character);
    }
    if !current.is_empty() {
        words.push(current);
    }

    let mut summary = words
        .into_iter()
        .map(|word| {
            let is_acronym =
                word.chars().nth(1).is_some() && !word.chars().any(char::is_lowercase);
            if is_acronym {
                word
            } else {
                word.to_lowercase()
            }
        })
        .collect::<Vec<_>>()
        .join(" ");
    // `get_mut(0..1)` is `None` when the first character is multi-byte, so this cannot panic.
    if let Some(first) = summary.get_mut(0..1) {
        first.make_ascii_uppercase();
    }
    summary
}
233
234fn tag_for_endpoint(endpoint: &AsyncAuthEndpoint, operation_id: Option<&str>) -> String {
235 if let Some(tag) = tag_for_operation_id(operation_id.unwrap_or_default()) {
236 return tag.to_owned();
237 }
238 let first_segment = endpoint
239 .path
240 .split('/')
241 .find(|segment| !segment.is_empty())
242 .unwrap_or_default();
243 tag_for_path_segment(first_segment)
244 .unwrap_or("Default")
245 .to_owned()
246}
247
/// Recognises a plugin tag from well-known fragments of an operation id.
///
/// Checks run in priority order, so an id matching several rules gets the first tag (e.g. an
/// id starting with `mcp` is always `MCP`). Returns `None` when nothing matches.
fn tag_for_operation_id(operation_id: &str) -> Option<&'static str> {
    let contains_any = |needles: &[&str]| needles.iter().any(|&needle| operation_id.contains(needle));

    if operation_id.starts_with("mcp") || operation_id.starts_with("getMcp") {
        Some("MCP")
    } else if contains_any(&["JWT", "JSONWeb"]) {
        // `contains("JWT")` already covers ids that end with "JWT"; no separate suffix check.
        Some("JWT")
    } else if operation_id.contains("OAuth2") {
        Some("Generic OAuth")
    } else if operation_id.contains("Siwe") {
        Some("SIWE")
    } else if operation_id.contains("PhoneNumber") {
        Some("Phone Number")
    } else if contains_any(&["TwoFactor", "BackupCode", "Otp"]) {
        Some("Two Factor")
    } else if operation_id.starts_with("organization") || operation_id.contains("Organization") {
        Some("Organization")
    } else {
        None
    }
}
273
/// Maps the first segment of an endpoint path to a plugin tag, or `None` if it is not known.
fn tag_for_path_segment(segment: &str) -> Option<&'static str> {
    // (path segment, tag) pairs; several segments may share one tag.
    const SEGMENT_TAGS: &[(&str, &str)] = &[
        ("mcp", "MCP"),
        ("admin", "Admin"),
        ("anonymous", "Anonymous"),
        ("delete-anonymous-user", "Anonymous"),
        ("device", "Device Authorization"),
        ("device-authorization", "Device Authorization"),
        ("email-otp", "Email OTP"),
        ("oauth2", "Generic OAuth"),
        ("jwt", "JWT"),
        ("jwks", "JWT"),
        ("token", "JWT"),
        ("magic-link", "Magic Link"),
        ("multi-session", "Multi Session"),
        ("oauth-proxy", "OAuth Proxy"),
        ("one-tap", "One Tap"),
        ("one-time-token", "One Time Token"),
        ("open-api", "Open API"),
        ("organization", "Organization"),
        ("phone-number", "Phone Number"),
        ("siwe", "SIWE"),
        ("two-factor", "Two Factor"),
        ("username", "Username"),
    ];
    SEGMENT_TAGS
        .iter()
        .find(|(known, _)| *known == segment)
        .map(|&(_, tag)| tag)
}
297
298pub fn build_openapi_schema(context: &AuthContext, async_endpoints: &[AsyncAuthEndpoint]) -> Value {
299 let mut paths = serde_json::Map::new();
300 for endpoint in async_endpoints {
301 if endpoint.options.server_only || endpoint.options.hide_from_openapi {
302 continue;
303 }
304 let path = paths
305 .entry(to_openapi_path(&endpoint.path))
306 .or_insert_with(|| Value::Object(serde_json::Map::new()));
307 let Value::Object(methods) = path else {
308 continue;
309 };
310 methods.insert(
311 endpoint.method.as_str().to_ascii_lowercase(),
312 openapi_operation_for_endpoint(endpoint),
313 );
314 }
315 json!({
316 "openapi": "3.1.1",
317 "info": {
318 "title": "RustAuth",
319 "description": "API Reference for your RustAuth instance",
320 "version": crate::VERSION,
321 },
322 "components": {
323 "schemas": openapi_model_schemas(context),
324 "securitySchemes": {
325 "apiKeyCookie": {
326 "type": "apiKey",
327 "in": "cookie",
328 "name": "apiKeyCookie",
329 "description": "API Key authentication via cookie",
330 },
331 "bearerAuth": {
332 "type": "http",
333 "scheme": "bearer",
334 "description": "Bearer token authentication",
335 },
336 },
337 },
338 "security": [
339 {
340 "apiKeyCookie": [],
341 "bearerAuth": [],
342 },
343 ],
344 "servers": [
345 {
346 "url": context.base_url,
347 },
348 ],
349 "tags": [
350 {
351 "name": "Default",
352 "description": "Default endpoints that are included with RustAuth by default. These endpoints are not part of any plugin.",
353 },
354 ],
355 "paths": paths,
356 })
357}
358
359fn method_uses_request_body(method: &Method) -> bool {
360 matches!(*method, Method::POST | Method::PATCH | Method::PUT)
361}
362
363pub(super) fn to_openapi_path(path: &str) -> String {
364 path.split('/')
365 .map(|part| {
366 part.strip_prefix(':')
367 .map(|name| format!("{{{name}}}"))
368 .unwrap_or_else(|| part.to_owned())
369 })
370 .collect::<Vec<_>>()
371 .join("/")
372}
373
374fn default_openapi_responses() -> serde_json::Map<String, Value> {
375 let mut responses = serde_json::Map::new();
376 responses.insert(
377 "400".to_owned(),
378 openapi_error_response(
379 "Bad Request. Usually due to missing parameters, or invalid parameters.",
380 true,
381 ),
382 );
383 responses.insert(
384 "401".to_owned(),
385 openapi_error_response(
386 "Unauthorized. Due to missing or invalid authentication.",
387 true,
388 ),
389 );
390 responses.insert(
391 "403".to_owned(),
392 openapi_error_response(
393 "Forbidden. You do not have permission to access this resource or to perform this action.",
394 false,
395 ),
396 );
397 responses.insert(
398 "404".to_owned(),
399 openapi_error_response("Not Found. The requested resource was not found.", false),
400 );
401 responses.insert(
402 "429".to_owned(),
403 openapi_error_response(
404 "Too Many Requests. You have exceeded the rate limit. Try again later.",
405 false,
406 ),
407 );
408 responses.insert(
409 "500".to_owned(),
410 openapi_error_response(
411 "Internal Server Error. This is a problem with the server that you cannot fix.",
412 false,
413 ),
414 );
415 responses
416}
417
418fn openapi_error_response(description: &str, require_message: bool) -> Value {
419 let mut required = vec!["code"];
420 if require_message {
421 required.push("message");
422 }
423 let mut schema = serde_json::Map::new();
424 schema.insert("type".to_owned(), Value::String("object".to_owned()));
425 schema.insert(
426 "properties".to_owned(),
427 json!({
428 "code": {
429 "type": "string",
430 },
431 "message": {
432 "type": "string",
433 },
434 "originalMessage": {
435 "type": "string",
436 },
437 }),
438 );
439 schema.insert("required".to_owned(), json!(required));
440 json!({
441 "content": {
442 "application/json": {
443 "schema": Value::Object(schema),
444 },
445 },
446 "description": description,
447 })
448}
449
450pub fn json_openapi_response(description: &str, schema: Value) -> Value {
451 json!({
452 "description": description,
453 "content": {
454 "application/json": {
455 "schema": schema,
456 },
457 },
458 })
459}
460
461pub fn empty_openapi_response(description: &str) -> Value {
462 json!({
463 "description": description,
464 })
465}
466
467pub fn redirect_openapi_response(description: &str) -> Value {
468 json!({
469 "description": description,
470 "headers": {
471 "Location": {
472 "description": "Redirect target",
473 "schema": {
474 "type": "string",
475 "format": "uri",
476 },
477 },
478 },
479 })
480}
481
482pub fn query_param(name: &str, description: &str) -> Value {
483 json!({
484 "name": name,
485 "in": "query",
486 "required": false,
487 "description": description,
488 "schema": {
489 "type": "string",
490 },
491 })
492}
493
494pub fn path_param(name: &str, description: &str) -> Value {
495 json!({
496 "name": name,
497 "in": "path",
498 "required": true,
499 "description": description,
500 "schema": {
501 "type": "string",
502 },
503 })
504}
505
506pub(super) fn openapi_model_schemas(context: &AuthContext) -> Value {
507 let mut schemas = serde_json::Map::new();
508 for (logical_table, table) in context.db_schema.tables() {
509 let mut properties = serde_json::Map::new();
510 let mut required = Vec::new();
511 for (logical_field, field) in &table.fields {
512 let property_name = openapi_property_name(logical_field);
513 if field.required {
514 required.push(Value::String(property_name.clone()));
515 }
516 properties.insert(
517 property_name,
518 openapi_field_schema(context, logical_table, logical_field, field),
519 );
520 }
521 match logical_table {
522 "user" => append_runtime_additional_fields(
523 context,
524 logical_table,
525 &mut properties,
526 &mut required,
527 &context.options.user.additional_fields,
528 ),
529 "session" => append_runtime_additional_fields(
530 context,
531 logical_table,
532 &mut properties,
533 &mut required,
534 &context.options.session.additional_fields,
535 ),
536 _ => {}
537 }
538
539 schemas.insert(
540 openapi_schema_name(logical_table),
541 json!({
542 "type": "object",
543 "properties": properties,
544 "required": required,
545 "additionalProperties": true,
546 }),
547 );
548 }
549 Value::Object(schemas)
550}
551
552fn append_runtime_additional_fields<F>(
553 context: &AuthContext,
554 logical_table: &str,
555 properties: &mut serde_json::Map<String, Value>,
556 required: &mut Vec<Value>,
557 fields: &std::collections::BTreeMap<String, F>,
558) where
559 F: RuntimeAdditionalField,
560{
561 for (logical_field, field) in fields {
562 let property_name = openapi_property_name(logical_field);
563 if properties.contains_key(&property_name) {
564 continue;
565 }
566 let db_field = DbField {
567 name: field
568 .db_name()
569 .map(str::to_owned)
570 .unwrap_or_else(|| logical_field.clone()),
571 field_type: field.field_type().clone(),
572 required: field.required(),
573 unique: false,
574 index: false,
575 returned: field.returned(),
576 input: field.input(),
577 foreign_key: None,
578 generated_id: None,
579 };
580 if db_field.required {
581 required.push(Value::String(property_name.clone()));
582 }
583 properties.insert(
584 property_name,
585 openapi_field_schema(context, logical_table, logical_field, &db_field),
586 );
587 }
588}
589
590fn openapi_field_schema(
591 context: &AuthContext,
592 logical_table: &str,
593 logical_field: &str,
594 field: &DbField,
595) -> Value {
596 let mut schema = serde_json::Map::new();
597 let type_name = openapi_field_type(&field.field_type);
598 if field.required {
599 schema.insert("type".to_owned(), Value::String(type_name.to_owned()));
600 } else {
601 schema.insert("type".to_owned(), json!([type_name, "null"]));
602 }
603 match field.field_type {
604 DbFieldType::String => {
605 if logical_field == "email" {
606 schema.insert("format".to_owned(), Value::String("email".to_owned()));
607 } else if logical_field == "image" || logical_field == "logo" {
608 schema.insert("format".to_owned(), Value::String("uri".to_owned()));
609 }
610 }
611 DbFieldType::Timestamp => {
612 schema.insert("format".to_owned(), Value::String("date-time".to_owned()));
613 }
614 DbFieldType::StringArray => {
615 schema.insert("items".to_owned(), json!({ "type": "string" }));
616 }
617 DbFieldType::NumberArray => {
618 schema.insert("items".to_owned(), json!({ "type": "number" }));
619 }
620 DbFieldType::Number | DbFieldType::Boolean | DbFieldType::Json => {}
621 }
622 if !field.input {
623 schema.insert("readOnly".to_owned(), Value::Bool(true));
624 }
625 if let Some(default_value) = openapi_field_default(context, logical_table, logical_field) {
626 schema.insert("default".to_owned(), default_value);
627 }
628 Value::Object(schema)
629}
630
631fn openapi_field_type(field_type: &DbFieldType) -> &'static str {
632 match field_type {
633 DbFieldType::String | DbFieldType::Timestamp => "string",
634 DbFieldType::Number => "number",
635 DbFieldType::Boolean => "boolean",
636 DbFieldType::Json => "object",
637 DbFieldType::StringArray | DbFieldType::NumberArray => "array",
638 }
639}
640
641fn openapi_field_default(
642 context: &AuthContext,
643 logical_table: &str,
644 logical_field: &str,
645) -> Option<Value> {
646 let value = match logical_table {
647 "user" => context
648 .options
649 .user
650 .additional_fields
651 .get(logical_field)
652 .and_then(|field| field.default_value.as_ref()),
653 "session" => context
654 .options
655 .session
656 .additional_fields
657 .get(logical_field)
658 .and_then(|field| field.default_value.as_ref()),
659 _ => None,
660 }?;
661 db_value_to_openapi_default(value)
662}
663
664fn db_value_to_openapi_default(value: &DbValue) -> Option<Value> {
665 match value {
666 DbValue::String(value) => Some(Value::String(value.clone())),
667 DbValue::Number(value) => Some(Value::Number((*value).into())),
668 DbValue::Boolean(value) => Some(Value::Bool(*value)),
669 DbValue::Json(value) => Some(value.clone()),
670 DbValue::StringArray(values) => Some(Value::Array(
671 values.iter().cloned().map(Value::String).collect(),
672 )),
673 DbValue::NumberArray(values) => Some(Value::Array(
674 values
675 .iter()
676 .map(|value| Value::Number((*value).into()))
677 .collect(),
678 )),
679 DbValue::Null => Some(Value::Null),
680 DbValue::Timestamp(_) | DbValue::Record(_) | DbValue::RecordArray(_) => None,
681 }
682}
683
/// Component schema name for a logical table: its PascalCase form.
///
/// Examples: `rate_limit` → `RateLimit`, `team_member` → `TeamMember`.
fn openapi_schema_name(logical_table: &str) -> String {
    // Every well-known table maps to exactly its PascalCase form, so one conversion covers
    // them all: user, session, account, verification, rate_limit, organization, member,
    // invitation, team, team_member, organization_role, wallet_address.
    pascal_case(logical_table)
}

/// JSON property name for a logical snake_case field name.
///
/// Example: `email_verified` → `emailVerified`.
fn openapi_property_name(logical_field: &str) -> String {
    snake_to_camel(logical_field)
}

/// Appends `word` to `output` with its first character uppercased (Unicode-aware).
///
/// The rest of `word` is copied unchanged; an empty `word` appends nothing.
fn push_capitalized(output: &mut String, word: &str) {
    let mut characters = word.chars();
    if let Some(first) = characters.next() {
        output.extend(first.to_uppercase());
        output.push_str(characters.as_str());
    }
}

/// Converts `snake_case` to `camelCase`.
///
/// Underscores are dropped and the character following a run of underscores is uppercased.
/// The text before the first underscore is kept as-is.
fn snake_to_camel(value: &str) -> String {
    let mut words = value.split('_');
    let mut output = String::with_capacity(value.len());
    output.push_str(words.next().unwrap_or_default());
    for word in words {
        push_capitalized(&mut output, word);
    }
    output
}

/// Converts `snake_case`, `kebab-case` or space-separated text to `PascalCase`.
///
/// `_`, `-` and spaces are dropped; the first character of each word is uppercased.
fn pascal_case(value: &str) -> String {
    let mut output = String::with_capacity(value.len());
    for word in value.split(|character: char| matches!(character, '_' | '-' | ' ')) {
        push_capitalized(&mut output, word);
    }
    output
}