1use std::sync::Arc;
4
5use bytes::Bytes;
6use serde_json::{json, Map, Value};
7
8use tork_core::constants::APPLICATION_JSON;
9use tork_core::{
10 bytes_response, BoxFuture, HandlerFn, Method, OpenApiProvider, RequestBodyKind, RequestContext,
11 Response, Result, Route, StatusCode,
12};
13
/// Path at which the JSON specification is served unless [`OpenApi::json`] overrides it.
const DEFAULT_JSON_PATH: &str = "/openapi.json";
/// Value written to the document's top-level `openapi` field.
const OPENAPI_VERSION: &str = "3.1.0";
18
19pub(crate) type DocGuard = Arc<dyn Fn(&RequestContext) -> bool + Send + Sync>;
22
23pub struct OpenApi {
29 title: String,
30 version: String,
31 description: Option<String>,
32 json_path: String,
33 docs_path: Option<String>,
34 guard: Option<DocGuard>,
35}
36
37impl Default for OpenApi {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl OpenApi {
44 pub fn new() -> Self {
46 Self {
47 title: "API".to_owned(),
48 version: "0.1.0".to_owned(),
49 description: None,
50 json_path: DEFAULT_JSON_PATH.to_owned(),
51 docs_path: None,
52 guard: None,
53 }
54 }
55
56 pub fn title(mut self, title: impl Into<String>) -> Self {
58 self.title = title.into();
59 self
60 }
61
62 pub fn version(mut self, version: impl Into<String>) -> Self {
64 self.version = version.into();
65 self
66 }
67
68 pub fn description(mut self, description: impl Into<String>) -> Self {
70 self.description = Some(description.into());
71 self
72 }
73
74 pub fn json(mut self, path: impl Into<String>) -> Self {
76 self.json_path = path.into();
77 self
78 }
79
80 pub fn docs(mut self, path: impl Into<String>) -> Self {
82 self.docs_path = Some(path.into());
83 self
84 }
85
86 pub fn protect<F>(mut self, predicate: F) -> Self
110 where
111 F: Fn(&RequestContext) -> bool + Send + Sync + 'static,
112 {
113 self.guard = Some(Arc::new(predicate));
114 self
115 }
116
117 pub fn build_document(&self, routes: &[Route]) -> Value {
119 build_document(self, routes)
120 }
121}
122
123impl OpenApiProvider for OpenApi {
124 fn documentation_routes(&self, registered: &[Route]) -> Vec<Route> {
125 let document = build_document(self, registered);
126 let body = serde_json::to_vec(&document).unwrap_or_default();
127
128 let mut routes = vec![spec_route(
129 &self.json_path,
130 Bytes::from(body),
131 self.guard.clone(),
132 )];
133 if let Some(docs_path) = &self.docs_path {
134 routes.push(crate::docs::docs_route(
135 docs_path,
136 &self.title,
137 &self.json_path,
138 self.guard.clone(),
139 ));
140 }
141 routes
142 }
143}
144
145pub(crate) fn check_guard(guard: &Option<DocGuard>, ctx: &RequestContext) -> Result<()> {
148 match guard {
149 Some(guard) if !guard(ctx) => Err(tork_core::Error::not_found("not found")),
150 _ => Ok(()),
151 }
152}
153
154fn spec_route(path: &str, body: Bytes, guard: Option<DocGuard>) -> Route {
156 let handler: HandlerFn = Arc::new(
157 move |ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
158 let body = body.clone();
159 let guard = guard.clone();
160 Box::pin(async move {
161 check_guard(&guard, &ctx)?;
162 Ok(bytes_response(StatusCode::OK, APPLICATION_JSON, body))
163 })
164 },
165 );
166
167 Route::new(Method::GET, path.to_owned(), handler).summary("OpenAPI specification")
168}
169
170fn build_document(api: &OpenApi, routes: &[Route]) -> Value {
172 let mut generator = schemars::generate::SchemaSettings::openapi3().into_generator();
175 let mut paths: Map<String, Value> = Map::new();
176
177 for route in routes {
178 let path = route.path().to_owned();
179 let method = route.method().as_str().to_lowercase();
180 let meta = route.meta();
181
182 let mut operation = Map::new();
183 if let Some(summary) = &meta.summary {
184 operation.insert("summary".to_owned(), json!(sanitize_doc_text(summary)));
185 }
186 if let Some(description) = &meta.description {
187 operation.insert(
188 "description".to_owned(),
189 json!(sanitize_doc_text(description)),
190 );
191 }
192 if !meta.tags.is_empty() {
193 let tags: Vec<String> = meta.tags.iter().map(|tag| sanitize_doc_text(tag)).collect();
194 operation.insert("tags".to_owned(), json!(tags));
195 }
196 operation.insert(
197 "operationId".to_owned(),
198 json!(operation_id(&method, &path)),
199 );
200
201 let parameters: Vec<Value> = placeholder_names(&path)
202 .into_iter()
203 .map(|name| {
204 json!({
205 "name": name,
206 "in": "path",
207 "required": true,
208 "schema": { "type": "string" },
209 })
210 })
211 .collect();
212 if !parameters.is_empty() {
213 operation.insert("parameters".to_owned(), json!(parameters));
214 }
215
216 if let Some(request_schema) = meta.request_schema {
217 let schema = request_schema(&mut generator).as_value().clone();
218 let media_type = match meta.request_kind {
222 RequestBodyKind::Json => "application/json",
223 RequestBodyKind::Form => "application/x-www-form-urlencoded",
224 RequestBodyKind::Multipart => "multipart/form-data",
225 };
226 operation.insert(
227 "requestBody".to_owned(),
228 json!({
229 "required": true,
230 "content": { media_type: { "schema": schema } },
231 }),
232 );
233 }
234
235 let status = meta.status_code.as_u16().to_string();
236 let mut response = Map::new();
237 let schema = meta
238 .response_schema
239 .map(|thunk| thunk(&mut generator).as_value().clone());
240 if meta.streaming {
241 response.insert("description".to_owned(), json!("Server-Sent Events stream"));
244 if let Some(schema) = schema {
245 response.insert(
246 "content".to_owned(),
247 json!({ "text/event-stream": { "schema": schema } }),
248 );
249 }
250 } else {
251 let reason = meta.status_code.canonical_reason().unwrap_or("Response");
252 response.insert("description".to_owned(), json!(reason));
253 if let Some(schema) = schema {
254 response.insert(
255 "content".to_owned(),
256 json!({ "application/json": { "schema": schema } }),
257 );
258 }
259 }
260 operation.insert(
261 "responses".to_owned(),
262 json!({ status: Value::Object(response) }),
263 );
264
265 let entry = paths
266 .entry(path)
267 .or_insert_with(|| Value::Object(Map::new()));
268 if let Some(object) = entry.as_object_mut() {
269 object.insert(method, Value::Object(operation));
270 }
271 }
272
273 let mut info = Map::new();
274 info.insert("title".to_owned(), json!(sanitize_doc_text(&api.title)));
275 info.insert("version".to_owned(), json!(api.version));
276 if let Some(description) = &api.description {
277 info.insert(
278 "description".to_owned(),
279 json!(sanitize_doc_text(description)),
280 );
281 }
282
283 let mut document = json!({
284 "openapi": OPENAPI_VERSION,
285 "info": Value::Object(info),
286 "paths": Value::Object(paths),
287 });
288
289 let definitions = generator.take_definitions(true);
291 if !definitions.is_empty() {
292 document["components"] = json!({ "schemas": Value::Object(definitions) });
293 }
294
295 document
296}
297
/// Escapes documentation text so it can be rendered safely inside HTML.
///
/// `&`, `<`, `>`, `"`, `'` and the backtick are replaced by HTML character
/// references. Newlines, carriage returns and tabs are kept. Every other control
/// character is replaced by a single space. All remaining characters, including
/// non-ASCII text, pass through unchanged.
pub(crate) fn sanitize_doc_text(value: &str) -> String {
    let mut sanitized = String::with_capacity(value.len());
    for ch in value.chars() {
        match html_reference_name(ch) {
            Some(name) => {
                sanitized.push('&');
                sanitized.push_str(name);
                sanitized.push(';');
            }
            None if matches!(ch, '\n' | '\r' | '\t') => sanitized.push(ch),
            None if ch.is_control() => sanitized.push(' '),
            None => sanitized.push(ch),
        }
    }
    sanitized
}

/// Returns the body of the HTML character reference for `ch` (the part between
/// `&` and `;`), or `None` when `ch` does not need escaping.
///
/// The references are assembled from these names rather than written out as complete
/// literals. No literal entity sits in the source, so none can be decoded back into
/// the raw character and silently turn the escape into a no-op.
fn html_reference_name(ch: char) -> Option<&'static str> {
    match ch {
        '&' => Some("amp"),
        '<' => Some("lt"),
        '>' => Some("gt"),
        '"' => Some("quot"),
        '\'' => Some("#x27"),
        '`' => Some("#x60"),
        _ => None,
    }
}
315
/// Derives an `operationId` from the HTTP method and the route path.
///
/// The result is `method` followed by `_<segment>` for each non-empty `/`-separated
/// segment, with every character that is not an ASCII letter or digit replaced by `_`
/// (for example `("get", "/teams/{id}")` gives `get_teams__id_`).
fn operation_id(method: &str, path: &str) -> String {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .fold(method.to_owned(), |mut id, segment| {
            id.push('_');
            id.extend(segment.chars().map(|ch| {
                if ch.is_ascii_alphanumeric() {
                    ch
                } else {
                    '_'
                }
            }));
            id
        })
}
327
/// Lists the names of the `{name}` placeholders in `path`, in order of appearance.
///
/// Leading `*` characters (catch-all markers such as `{*rest}`) are stripped from a
/// name. A `{` without a later `}` ends the scan.
fn placeholder_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut remaining = path;

    while let Some(open) = remaining.find('{') {
        let after_open = &remaining[open + 1..];
        match after_open.find('}') {
            Some(close) => {
                names.push(after_open[..close].trim_start_matches('*').to_owned());
                remaining = &after_open[close + 1..];
            }
            // No closing brace anywhere further on, so no later placeholder can close either.
            None => break,
        }
    }

    names
}
348
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_handler() -> HandlerFn {
        Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async {
                    Ok(bytes_response(
                        StatusCode::OK,
                        APPLICATION_JSON,
                        Bytes::new(),
                    ))
                })
            },
        )
    }

    #[test]
    fn document_describes_routes() {
        let routes = vec![Route::new(Method::GET, "/users/{user_id}", dummy_handler())
            .summary("Get user")
            .tag("users")];

        let document = OpenApi::new()
            .title("My API")
            .version("1.0.0")
            .build_document(&routes);

        assert_eq!(document["openapi"], OPENAPI_VERSION);
        assert_eq!(document["info"]["title"], "My API");
        assert_eq!(document["info"]["version"], "1.0.0");

        let operation = &document["paths"]["/users/{user_id}"]["get"];
        assert_eq!(operation["summary"], "Get user");
        assert_eq!(operation["tags"][0], "users");
        assert_eq!(operation["parameters"][0]["name"], "user_id");
        assert_eq!(operation["parameters"][0]["in"], "path");
        assert!(operation["responses"]["200"].is_object());
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Sample {
        id: i64,
        label: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Inner {
        value: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Outer {
        inner: Inner,
    }

    #[test]
    fn nested_models_are_registered_as_components() {
        let routes =
            vec![Route::new(Method::GET, "/outer", dummy_handler()).response_schema::<Outer>()];

        let schemas = &OpenApi::new().build_document(&routes)["components"]["schemas"];
        assert!(schemas["Outer"].is_object(), "outer missing: {schemas}");
        assert!(
            schemas["Inner"].is_object(),
            "nested inner missing: {schemas}"
        );
    }

    #[test]
    fn document_includes_component_schemas() {
        let routes = vec![Route::new(Method::POST, "/samples", dummy_handler())
            .request_schema::<Sample>()
            .response_schema::<Sample>()];

        let document = OpenApi::new().build_document(&routes);

        assert!(
            document["components"]["schemas"]["Sample"].is_object(),
            "document: {document}"
        );

        let operation = &document["paths"]["/samples"]["post"];
        let request_ref =
            &operation["requestBody"]["content"]["application/json"]["schema"]["$ref"];
        let response_ref =
            &operation["responses"]["200"]["content"]["application/json"]["schema"]["$ref"];
        assert_eq!(request_ref, "#/components/schemas/Sample");
        assert_eq!(response_ref, "#/components/schemas/Sample");
    }

    #[test]
    fn multipart_route_documents_form_data_with_binary_file() {
        fn form_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::Schema::try_from(json!({
                "type": "object",
                "properties": {
                    "token": { "type": "string" },
                    "file": { "type": "string", "format": "binary" },
                },
                "required": ["token", "file"],
            }))
            .unwrap()
        }

        let routes = vec![Route::new(Method::POST, "/files", dummy_handler())
            .request_schema_fn(form_schema)
            .request_kind(RequestBodyKind::Multipart)];

        let document = OpenApi::new().build_document(&routes);
        let content = &document["paths"]["/files"]["post"]["requestBody"]["content"];

        let schema = &content["multipart/form-data"]["schema"];
        assert_eq!(schema["properties"]["file"]["format"], "binary");
        assert!(
            content["application/json"].is_null(),
            "multipart body must not be JSON: {content}"
        );
    }

    #[test]
    fn urlencoded_route_documents_form_content_type() {
        let routes = vec![Route::new(Method::POST, "/login", dummy_handler())
            .request_schema::<Sample>()
            .request_kind(RequestBodyKind::Form)];

        let document = OpenApi::new().build_document(&routes);
        let content = &document["paths"]["/login"]["post"]["requestBody"]["content"];

        assert!(
            content["application/x-www-form-urlencoded"]["schema"].is_object(),
            "expected urlencoded body: {content}"
        );
        assert!(content["application/json"].is_null());
    }

    #[test]
    fn streaming_route_documents_event_stream() {
        let routes = vec![Route::new(Method::GET, "/stream", dummy_handler())
            .response_schema::<Sample>()
            .streaming()];

        let document = OpenApi::new().build_document(&routes);
        let response = &document["paths"]["/stream"]["get"]["responses"]["200"];

        assert_eq!(response["description"], "Server-Sent Events stream");
        assert_eq!(
            response["content"]["text/event-stream"]["schema"]["$ref"],
            "#/components/schemas/Sample"
        );
        assert!(
            response["content"]["application/json"].is_null(),
            "streaming response must not be JSON: {response}"
        );
    }

    #[test]
    fn provider_registers_spec_and_docs_routes() {
        let provider = OpenApi::new()
            .title("Docs")
            .version("1.2.3")
            .json("/schema.json")
            .docs("/docs");

        let routes = provider.documentation_routes(&[]);

        assert_eq!(routes.len(), 2);
        assert_eq!(routes[0].path(), "/schema.json");
        assert_eq!(routes[1].path(), "/docs");
    }

    #[test]
    fn operation_id_and_placeholder_helpers_cover_edge_cases() {
        assert_eq!(operation_id("patch", "/"), "patch");
        assert_eq!(
            operation_id("get", "/teams/{team-id}/members/{*rest}"),
            "get_teams__team_id__members___rest_"
        );
        assert_eq!(
            placeholder_names("/teams/{team_id}/members/{*rest}"),
            vec!["team_id".to_owned(), "rest".to_owned()]
        );
    }

    #[test]
    fn document_sanitizes_route_and_info_text_fields() {
        let routes = vec![Route::new(Method::GET, "/users/{user_id}", dummy_handler())
            .summary("<script>alert(1)</script>")
            .description("bad\u{0007}`quote`")
            .tag("ops<script>")];

        let document = OpenApi::new()
            .title("Docs <unsafe>")
            .description("line\u{0001}two")
            .build_document(&routes);

        // The exact character references are an implementation detail of
        // `sanitize_doc_text`; these checks pin that markup characters are gone
        // while the surrounding text survives.
        let operation = &document["paths"]["/users/{user_id}"]["get"];

        let summary = operation["summary"].as_str().expect("summary is a string");
        assert!(
            !summary.contains('<') && !summary.contains('>'),
            "summary not escaped: {summary}"
        );
        assert!(summary.contains("alert(1)"), "summary text lost: {summary}");

        let description = operation["description"]
            .as_str()
            .expect("description is a string");
        assert!(
            description.starts_with("bad "),
            "control character not replaced: {description}"
        );
        assert!(
            !description.contains('`') && description.contains("quote"),
            "backticks not escaped: {description}"
        );

        let tag = operation["tags"][0].as_str().expect("tag is a string");
        assert!(
            tag.starts_with("ops") && !tag.contains('<') && !tag.contains('>'),
            "tag not escaped: {tag}"
        );

        let title = document["info"]["title"].as_str().expect("title is a string");
        assert!(
            title.starts_with("Docs ") && !title.contains('<') && !title.contains('>'),
            "title not escaped: {title}"
        );
        assert_eq!(document["info"]["description"], "line two");
    }
}