Skip to main content

tork_openapi/
spec.rs

1//! The `OpenApi` builder and specification document assembly.
2
3use std::sync::Arc;
4
5use bytes::Bytes;
6use serde_json::{json, Map, Value};
7
8use tork_core::constants::APPLICATION_JSON;
9use tork_core::{
10    bytes_response, BoxFuture, HandlerFn, Method, OpenApiProvider, RequestBodyKind, RequestContext,
11    Response, Result, Route, StatusCode,
12};
13
/// OpenAPI specification version emitted as the document's top-level `openapi`
/// field.
const OPENAPI_VERSION: &str = "3.1.0";
/// Default path at which the specification document is served; overridable via
/// `OpenApi::json`.
const DEFAULT_JSON_PATH: &str = "/openapi.json";

/// A predicate gating access to the documentation routes. Returning `false`
/// hides the spec and docs UI behind a `404`.
///
/// Held behind an `Arc` so one predicate can be cloned into both the spec route
/// and the docs UI route handlers.
pub(crate) type DocGuard = Arc<dyn Fn(&RequestContext) -> bool + Send + Sync>;
22
23/// Configures OpenAPI document generation.
24///
25/// The document describes paths, methods, summaries, descriptions, tags, path
26/// parameters, and — for routes whose handlers use `#[api_model]` bodies and
27/// return types — request and response body schemas under `components.schemas`.
28pub struct OpenApi {
29    title: String,
30    version: String,
31    description: Option<String>,
32    json_path: String,
33    docs_path: Option<String>,
34    guard: Option<DocGuard>,
35}
36
37impl Default for OpenApi {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl OpenApi {
44    /// Creates a builder with default title, version, and document path.
45    pub fn new() -> Self {
46        Self {
47            title: "API".to_owned(),
48            version: "0.1.0".to_owned(),
49            description: None,
50            json_path: DEFAULT_JSON_PATH.to_owned(),
51            docs_path: None,
52            guard: None,
53        }
54    }
55
56    /// Sets the API title.
57    pub fn title(mut self, title: impl Into<String>) -> Self {
58        self.title = title.into();
59        self
60    }
61
62    /// Sets the API version.
63    pub fn version(mut self, version: impl Into<String>) -> Self {
64        self.version = version.into();
65        self
66    }
67
68    /// Sets the API description.
69    pub fn description(mut self, description: impl Into<String>) -> Self {
70        self.description = Some(description.into());
71        self
72    }
73
74    /// Sets the path at which the specification document is served.
75    pub fn json(mut self, path: impl Into<String>) -> Self {
76        self.json_path = path.into();
77        self
78    }
79
80    /// Enables the documentation UI, served at `path`.
81    pub fn docs(mut self, path: impl Into<String>) -> Self {
82        self.docs_path = Some(path.into());
83        self
84    }
85
86    /// Restricts access to the spec and docs routes to requests the predicate
87    /// accepts; rejected requests get a `404` (hiding that the routes exist).
88    ///
89    /// Use this to keep the API surface from being publicly discoverable — for
90    /// example, gate it on a bearer token, an internal network, or an environment
91    /// flag. The predicate runs on every request to the documentation routes.
92    ///
93    /// Compare the credential with [`constant_time_eq`](tork_core::security::constant_time_eq)
94    /// rather than `==`, so the check does not leak how many bytes matched via its
95    /// timing:
96    ///
97    /// ```
98    /// # use tork_openapi::OpenApi;
99    /// use tork_core::security::constant_time_eq;
100    /// let api = OpenApi::new().docs("/docs").protect(|ctx| {
101    ///     ctx.headers()
102    ///         .get("authorization")
103    ///         .and_then(|v| v.to_str().ok())
104    ///         .map(|header| constant_time_eq(header, "Bearer secret-docs-token"))
105    ///         .unwrap_or(false)
106    /// });
107    /// # let _ = api;
108    /// ```
109    pub fn protect<F>(mut self, predicate: F) -> Self
110    where
111        F: Fn(&RequestContext) -> bool + Send + Sync + 'static,
112    {
113        self.guard = Some(Arc::new(predicate));
114        self
115    }
116
117    /// Builds the OpenAPI document for the given routes as a JSON value.
118    pub fn build_document(&self, routes: &[Route]) -> Value {
119        build_document(self, routes)
120    }
121}
122
123impl OpenApiProvider for OpenApi {
124    fn documentation_routes(&self, registered: &[Route]) -> Vec<Route> {
125        let document = build_document(self, registered);
126        let body = serde_json::to_vec(&document).unwrap_or_default();
127
128        let mut routes = vec![spec_route(
129            &self.json_path,
130            Bytes::from(body),
131            self.guard.clone(),
132        )];
133        if let Some(docs_path) = &self.docs_path {
134            routes.push(crate::docs::docs_route(
135                docs_path,
136                &self.title,
137                &self.json_path,
138                self.guard.clone(),
139            ));
140        }
141        routes
142    }
143}
144
145/// Rejects a request with `404` when a documentation guard denies it, hiding the
146/// route's existence; returns `Ok(())` when there is no guard or it allows access.
147pub(crate) fn check_guard(guard: &Option<DocGuard>, ctx: &RequestContext) -> Result<()> {
148    match guard {
149        Some(guard) if !guard(ctx) => Err(tork_core::Error::not_found("not found")),
150        _ => Ok(()),
151    }
152}
153
154/// Builds a route that serves a pre-serialized document at `path`.
155fn spec_route(path: &str, body: Bytes, guard: Option<DocGuard>) -> Route {
156    let handler: HandlerFn = Arc::new(
157        move |ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
158            let body = body.clone();
159            let guard = guard.clone();
160            Box::pin(async move {
161                check_guard(&guard, &ctx)?;
162                Ok(bytes_response(StatusCode::OK, APPLICATION_JSON, body))
163            })
164        },
165    );
166
167    Route::new(Method::GET, path.to_owned(), handler).summary("OpenAPI specification")
168}
169
170/// Assembles the OpenAPI document from the route table.
171fn build_document(api: &OpenApi, routes: &[Route]) -> Value {
172    // A single generator collects every model schema; with the OpenAPI 3 settings
173    // its `$ref`s already point at `#/components/schemas/...`.
174    let mut generator = schemars::generate::SchemaSettings::openapi3().into_generator();
175    let mut paths: Map<String, Value> = Map::new();
176
177    for route in routes {
178        let path = route.path().to_owned();
179        let method = route.method().as_str().to_lowercase();
180        let meta = route.meta();
181
182        let mut operation = Map::new();
183        if let Some(summary) = &meta.summary {
184            operation.insert("summary".to_owned(), json!(sanitize_doc_text(summary)));
185        }
186        if let Some(description) = &meta.description {
187            operation.insert(
188                "description".to_owned(),
189                json!(sanitize_doc_text(description)),
190            );
191        }
192        if !meta.tags.is_empty() {
193            let tags: Vec<String> = meta.tags.iter().map(|tag| sanitize_doc_text(tag)).collect();
194            operation.insert("tags".to_owned(), json!(tags));
195        }
196        operation.insert(
197            "operationId".to_owned(),
198            json!(operation_id(&method, &path)),
199        );
200
201        let parameters: Vec<Value> = placeholder_names(&path)
202            .into_iter()
203            .map(|name| {
204                json!({
205                    "name": name,
206                    "in": "path",
207                    "required": true,
208                    "schema": { "type": "string" },
209                })
210            })
211            .collect();
212        if !parameters.is_empty() {
213            operation.insert("parameters".to_owned(), json!(parameters));
214        }
215
216        if let Some(request_schema) = meta.request_schema {
217            let schema = request_schema(&mut generator).as_value().clone();
218            // The media type follows the declared body encoding: JSON bodies,
219            // urlencoded forms, or multipart forms (whose file fields are marked
220            // `format: binary` in the schema).
221            let media_type = match meta.request_kind {
222                RequestBodyKind::Json => "application/json",
223                RequestBodyKind::Form => "application/x-www-form-urlencoded",
224                RequestBodyKind::Multipart => "multipart/form-data",
225            };
226            operation.insert(
227                "requestBody".to_owned(),
228                json!({
229                    "required": true,
230                    "content": { media_type: { "schema": schema } },
231                }),
232            );
233        }
234
235        let status = meta.status_code.as_u16().to_string();
236        let mut response = Map::new();
237        let schema = meta
238            .response_schema
239            .map(|thunk| thunk(&mut generator).as_value().clone());
240        if meta.streaming {
241            // A Server-Sent Events stream: each message carries a JSON-encoded
242            // value of this schema in its `data:` field.
243            response.insert("description".to_owned(), json!("Server-Sent Events stream"));
244            if let Some(schema) = schema {
245                response.insert(
246                    "content".to_owned(),
247                    json!({ "text/event-stream": { "schema": schema } }),
248                );
249            }
250        } else {
251            let reason = meta.status_code.canonical_reason().unwrap_or("Response");
252            response.insert("description".to_owned(), json!(reason));
253            if let Some(schema) = schema {
254                response.insert(
255                    "content".to_owned(),
256                    json!({ "application/json": { "schema": schema } }),
257                );
258            }
259        }
260        operation.insert(
261            "responses".to_owned(),
262            json!({ status: Value::Object(response) }),
263        );
264
265        let entry = paths
266            .entry(path)
267            .or_insert_with(|| Value::Object(Map::new()));
268        if let Some(object) = entry.as_object_mut() {
269            object.insert(method, Value::Object(operation));
270        }
271    }
272
273    let mut info = Map::new();
274    info.insert("title".to_owned(), json!(sanitize_doc_text(&api.title)));
275    info.insert("version".to_owned(), json!(api.version));
276    if let Some(description) = &api.description {
277        info.insert(
278            "description".to_owned(),
279            json!(sanitize_doc_text(description)),
280        );
281    }
282
283    let mut document = json!({
284        "openapi": OPENAPI_VERSION,
285        "info": Value::Object(info),
286        "paths": Value::Object(paths),
287    });
288
289    // Emit every collected model schema under components.schemas.
290    let definitions = generator.take_definitions(true);
291    if !definitions.is_empty() {
292        document["components"] = json!({ "schemas": Value::Object(definitions) });
293    }
294
295    document
296}
297
/// Makes free text safe to embed in the generated document.
///
/// `&`, `<`, `>`, `"`, `'` and `` ` `` are replaced by HTML character entities.
/// Control characters other than `\n`, `\r` and `\t` become a single space. All
/// other characters are copied unchanged.
pub(crate) fn sanitize_doc_text(value: &str) -> String {
    /// The entity replacing `ch`, if `ch` must be escaped.
    fn entity(ch: char) -> Option<&'static str> {
        match ch {
            '&' => Some("&amp;"),
            '<' => Some("&lt;"),
            '>' => Some("&gt;"),
            '"' => Some("&quot;"),
            '\'' => Some("&#x27;"),
            '`' => Some("&#x60;"),
            _ => None,
        }
    }

    let mut out = String::with_capacity(value.len());
    for c in value.chars() {
        if let Some(escaped) = entity(c) {
            out.push_str(escaped);
        } else if c.is_control() && !matches!(c, '\n' | '\r' | '\t') {
            out.push(' ');
        } else {
            out.push(c);
        }
    }
    out
}
315
/// Derives a stable `operationId` from the method and path.
///
/// The result is the method followed by `_<segment>` for every non-empty path
/// segment, with each non-ASCII-alphanumeric character of a segment replaced by
/// `_` (e.g. `get` + `/users/{id}` → `get_users__id_`).
fn operation_id(method: &str, path: &str) -> String {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .fold(method.to_owned(), |mut id, segment| {
            id.push('_');
            id.extend(
                segment
                    .chars()
                    .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' }),
            );
            id
        })
}
327
/// Extracts the placeholder names from a path, e.g. `["user_id"]`.
///
/// Each `{...}` span yields its contents with any leading `*` (catch-all marker)
/// removed. An unmatched `{` yields nothing.
fn placeholder_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut rest = path;

    while let Some(open) = rest.find('{') {
        let after_open = &rest[open + 1..];
        // If no `}` follows this `{`, none can follow any later `{` either.
        let Some(close) = after_open.find('}') else {
            break;
        };
        names.push(after_open[..close].trim_start_matches('*').to_owned());
        rest = &after_open[close + 1..];
    }

    names
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a handler that answers `200` with an empty JSON body; these tests
    /// only inspect the generated documentation, never handler output.
    fn noop_handler() -> HandlerFn {
        let respond = |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
            Box::pin(async {
                Ok(bytes_response(
                    StatusCode::OK,
                    APPLICATION_JSON,
                    Bytes::new(),
                ))
            })
        };
        Arc::new(respond)
    }

    #[test]
    fn document_describes_routes() {
        let route = Route::new(Method::GET, "/users/{user_id}", noop_handler())
            .summary("Get user")
            .tag("users");
        let api = OpenApi::new().title("My API").version("1.0.0");
        let document = api.build_document(&[route]);

        assert_eq!(document["openapi"], OPENAPI_VERSION);
        let info = &document["info"];
        assert_eq!(info["title"], "My API");
        assert_eq!(info["version"], "1.0.0");

        let get_user = &document["paths"]["/users/{user_id}"]["get"];
        assert_eq!(get_user["summary"], "Get user");
        assert_eq!(get_user["tags"][0], "users");
        let user_id = &get_user["parameters"][0];
        assert_eq!(user_id["name"], "user_id");
        assert_eq!(user_id["in"], "path");
        assert!(get_user["responses"]["200"].is_object());
    }

    // The model types below are never constructed or read; they exist only so
    // `JsonSchema` can derive schemas for them, hence `allow(dead_code)`. Their
    // names matter: they become the `components.schemas` keys asserted on.
    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Sample {
        id: i64,
        label: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Inner {
        value: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Outer {
        inner: Inner,
    }

    #[test]
    fn nested_models_are_registered_as_components() {
        let route = Route::new(Method::GET, "/outer", noop_handler()).response_schema::<Outer>();
        let document = OpenApi::new().build_document(&[route]);
        let schemas = &document["components"]["schemas"];

        assert!(schemas["Outer"].is_object(), "outer missing: {schemas}");
        assert!(
            schemas["Inner"].is_object(),
            "nested inner missing: {schemas}"
        );
    }

    #[test]
    fn document_includes_component_schemas() {
        let route = Route::new(Method::POST, "/samples", noop_handler())
            .request_schema::<Sample>()
            .response_schema::<Sample>();
        let document = OpenApi::new().build_document(&[route]);

        // The model is registered once under components.schemas.
        assert!(
            document["components"]["schemas"]["Sample"].is_object(),
            "document: {document}"
        );

        let create = &document["paths"]["/samples"]["post"];
        let expected_ref = "#/components/schemas/Sample";
        assert_eq!(
            create["requestBody"]["content"]["application/json"]["schema"]["$ref"],
            expected_ref
        );
        assert_eq!(
            create["responses"]["200"]["content"]["application/json"]["schema"]["$ref"],
            expected_ref
        );
    }

    #[test]
    fn multipart_route_documents_form_data_with_binary_file() {
        // A form schema thunk shaped like the one generated by #[derive(FormModel)]:
        // a text field plus a binary file field.
        fn upload_form_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            let shape = json!({
                "type": "object",
                "properties": {
                    "token": { "type": "string" },
                    "file": { "type": "string", "format": "binary" },
                },
                "required": ["token", "file"],
            });
            schemars::Schema::try_from(shape).unwrap()
        }

        let route = Route::new(Method::POST, "/files", noop_handler())
            .request_schema_fn(upload_form_schema)
            .request_kind(RequestBodyKind::Multipart);
        let document = OpenApi::new().build_document(&[route]);
        let content = &document["paths"]["/files"]["post"]["requestBody"]["content"];

        let form = &content["multipart/form-data"]["schema"];
        assert_eq!(form["properties"]["file"]["format"], "binary");
        assert!(
            content["application/json"].is_null(),
            "multipart body must not be JSON: {content}"
        );
    }

    #[test]
    fn urlencoded_route_documents_form_content_type() {
        let route = Route::new(Method::POST, "/login", noop_handler())
            .request_schema::<Sample>()
            .request_kind(RequestBodyKind::Form);
        let document = OpenApi::new().build_document(&[route]);
        let content = &document["paths"]["/login"]["post"]["requestBody"]["content"];

        let urlencoded = &content["application/x-www-form-urlencoded"];
        assert!(
            urlencoded["schema"].is_object(),
            "expected urlencoded body: {content}"
        );
        assert!(content["application/json"].is_null());
    }

    #[test]
    fn streaming_route_documents_event_stream() {
        let route = Route::new(Method::GET, "/stream", noop_handler())
            .response_schema::<Sample>()
            .streaming();
        let document = OpenApi::new().build_document(&[route]);
        let ok = &document["paths"]["/stream"]["get"]["responses"]["200"];
        let content = &ok["content"];

        assert_eq!(ok["description"], "Server-Sent Events stream");
        assert_eq!(
            content["text/event-stream"]["schema"]["$ref"],
            "#/components/schemas/Sample"
        );
        assert!(
            content["application/json"].is_null(),
            "streaming response must not be JSON: {ok}"
        );
    }

    #[test]
    fn provider_registers_spec_and_docs_routes() {
        let provider = OpenApi::new()
            .title("Docs")
            .version("1.2.3")
            .json("/schema.json")
            .docs("/docs");

        let routes = provider.documentation_routes(&[]);

        assert_eq!(routes.len(), 2);
        let (spec, docs) = (&routes[0], &routes[1]);
        assert_eq!(spec.path(), "/schema.json");
        assert_eq!(docs.path(), "/docs");
    }

    #[test]
    fn operation_id_and_placeholder_helpers_cover_edge_cases() {
        let id_cases = [
            ("patch", "/", "patch"),
            (
                "get",
                "/teams/{team-id}/members/{*rest}",
                "get_teams__team_id__members___rest_",
            ),
        ];
        for (method, path, expected) in id_cases {
            assert_eq!(operation_id(method, path), expected);
        }

        assert_eq!(
            placeholder_names("/teams/{team_id}/members/{*rest}"),
            vec!["team_id".to_owned(), "rest".to_owned()]
        );
    }

    #[test]
    fn document_sanitizes_route_and_info_text_fields() {
        let route = Route::new(Method::GET, "/users/{user_id}", noop_handler())
            .summary("<script>alert(1)</script>")
            .description("bad\u{0007}`quote`")
            .tag("ops<script>");
        let api = OpenApi::new()
            .title("Docs <unsafe>")
            .description("line\u{0001}two");
        let document = api.build_document(&[route]);

        let get_user = &document["paths"]["/users/{user_id}"]["get"];
        assert_eq!(
            get_user["summary"],
            "&lt;script&gt;alert(1)&lt;/script&gt;"
        );
        assert_eq!(get_user["description"], "bad &#x60;quote&#x60;");
        assert_eq!(get_user["tags"][0], "ops&lt;script&gt;");

        let info = &document["info"];
        assert_eq!(info["title"], "Docs &lt;unsafe&gt;");
        assert_eq!(info["description"], "line two");
    }
}