Skip to main content

structured_proxy/
openapi.rs

//! OpenAPI 3.0 spec generation from proto descriptors.
//!
//! Reads `google.api.http` annotations and proto message definitions
//! to produce a complete OpenAPI 3.0 JSON spec at runtime.
//! No codegen and no build step: it uses the same descriptor pool as transcoding.
6
7use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor, MethodDescriptor};
8use serde_json::{json, Map, Value};
9
10use crate::config::{AliasConfig, OpenApiConfig};
11
12/// Generate OpenAPI 3.0 JSON spec from a descriptor pool.
13pub fn generate(pool: &DescriptorPool, config: &OpenApiConfig, aliases: &[AliasConfig]) -> Value {
14    let title = config.title.as_deref().unwrap_or("API");
15    let version = config.version.as_deref().unwrap_or("1.0.0");
16
17    let mut paths = Map::new();
18    let mut schemas = Map::new();
19    let mut tags = Vec::new();
20
21    for service in pool.services() {
22        let service_name = service.name().to_string();
23        let service_full = service.full_name().to_string();
24
25        // Proto comments as tag description.
26        let tag_desc = get_comments(&service_full, pool);
27        let mut tag = json!({ "name": service_name });
28        if let Some(desc) = &tag_desc {
29            tag["description"] = json!(desc);
30        }
31        tags.push(tag);
32
33        for method in service.methods() {
34            if method.is_client_streaming() {
35                continue; // No REST mapping for client-streaming.
36            }
37
38            if let Some((http_method, http_path)) = extract_http_rule(&method, pool) {
39                let operation = build_operation(
40                    &method,
41                    &service_name,
42                    &http_method,
43                    &http_path,
44                    pool,
45                    &mut schemas,
46                );
47
48                // Main path.
49                add_path_operation(&mut paths, &http_path, &http_method, operation.clone());
50
51                // Aliases.
52                for alias in aliases {
53                    if let Some(suffix) = http_path.strip_prefix(&alias.to) {
54                        if alias.from.ends_with("/{path}") {
55                            let prefix = alias.from.trim_end_matches("/{path}");
56                            let alias_path = format!("{}{}", prefix, suffix);
57                            add_path_operation(
58                                &mut paths,
59                                &alias_path,
60                                &http_method,
61                                operation.clone(),
62                            );
63                        }
64                    }
65                }
66            }
67        }
68    }
69
70    let mut spec = json!({
71        "openapi": "3.0.3",
72        "info": {
73            "title": title,
74            "version": version,
75        },
76        "paths": paths,
77        "tags": tags,
78    });
79
80    if !schemas.is_empty() {
81        spec["components"] = json!({
82            "schemas": schemas,
83        });
84    }
85
86    // Security scheme for Bearer auth (cookie auth works implicitly via same-origin).
87    spec["components"]["securitySchemes"] = json!({
88        "bearerAuth": {
89            "type": "http",
90            "scheme": "bearer",
91            "bearerFormat": "JWT",
92        },
93    });
94
95    spec
96}
97
/// Generate the Scalar API reference HTML page.
///
/// `openapi_path` is the URL the page loads the spec from (placed in the
/// `data-url` attribute), and `title` is shown in the browser tab. Both values
/// are HTML-escaped before interpolation. Without escaping, a title such as
/// `R&D <beta>` or a path containing a double quote would corrupt the markup.
/// The Scalar bundle itself is loaded from the jsDelivr CDN.
pub fn docs_html(openapi_path: &str, title: &str) -> String {
    /// Escape the HTML-significant characters so `raw` is safe in both text
    /// content and a double-quoted attribute value.
    fn escape_html(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#39;"),
                _ => out.push(ch),
            }
        }
        out
    }

    format!(
        r#"<!DOCTYPE html>
<html>
<head>
    <title>{title} — API Docs</title>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
</head>
<body>
    <script id="api-reference" data-url="{openapi_path}"></script>
    <script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>"#,
        title = escape_html(title),
        openapi_path = escape_html(openapi_path),
    )
}
117
118fn add_path_operation(paths: &mut Map<String, Value>, path: &str, method: &str, operation: Value) {
119    let path_item = paths.entry(path.to_string()).or_insert_with(|| json!({}));
120    if let Some(obj) = path_item.as_object_mut() {
121        obj.insert(method.to_string(), operation);
122    }
123}
124
125fn build_operation(
126    method: &MethodDescriptor,
127    service_name: &str,
128    http_method: &str,
129    http_path: &str,
130    pool: &DescriptorPool,
131    schemas: &mut Map<String, Value>,
132) -> Value {
133    let method_name = method.name().to_string();
134    let full_name = method.full_name().to_string();
135    let input = method.input();
136    let output = method.output();
137
138    let is_streaming = method.is_server_streaming();
139
140    // Description from proto comments.
141    let description = get_comments(&full_name, pool).unwrap_or_default();
142
143    let operation_id = format!("{}.{}", service_name, method_name);
144
145    let mut op = json!({
146        "operationId": operation_id,
147        "tags": [service_name],
148        "summary": method_name,
149    });
150
151    if !description.is_empty() {
152        op["description"] = json!(description);
153    }
154
155    // Path parameters.
156    let path_params = extract_path_params(http_path);
157    if !path_params.is_empty() {
158        let params: Vec<Value> = path_params
159            .iter()
160            .map(|name| {
161                let mut param = json!({
162                    "name": name,
163                    "in": "path",
164                    "required": true,
165                    "schema": { "type": "string" },
166                });
167
168                // Try to get type from input message field.
169                if let Some(field) = input.get_field_by_name(name) {
170                    param["schema"] = field_to_schema(&field);
171                }
172
173                param
174            })
175            .collect();
176        op["parameters"] = json!(params);
177    }
178
179    // Request body (for POST/PUT/PATCH/DELETE with body fields).
180    if http_method != "get" {
181        let has_body_fields = input
182            .fields()
183            .any(|f| !path_params.contains(&f.name().to_string()));
184
185        if has_body_fields {
186            let schema_name = input.name().to_string();
187            let body_schema = message_to_schema(&input, &path_params, schemas);
188
189            schemas.insert(schema_name.clone(), body_schema);
190
191            op["requestBody"] = json!({
192                "required": true,
193                "content": {
194                    "application/json": {
195                        "schema": {
196                            "$ref": format!("#/components/schemas/{}", schema_name),
197                        },
198                    },
199                },
200            });
201        }
202    } else {
203        // GET: non-path fields become query parameters.
204        let query_params: Vec<Value> = input
205            .fields()
206            .filter(|f| !path_params.contains(&f.name().to_string()))
207            .map(|field| {
208                json!({
209                    "name": field.name(),
210                    "in": "query",
211                    "required": false,
212                    "schema": field_to_schema(&field),
213                })
214            })
215            .collect();
216
217        if !query_params.is_empty() {
218            let existing = op
219                .get("parameters")
220                .and_then(|v| v.as_array())
221                .cloned()
222                .unwrap_or_default();
223            let mut all_params = existing;
224            all_params.extend(query_params);
225            op["parameters"] = json!(all_params);
226        }
227    }
228
229    // Response.
230    if is_streaming {
231        op["responses"] = json!({
232            "200": {
233                "description": "Server-streaming response (NDJSON)",
234                "content": {
235                    "application/x-ndjson": {
236                        "schema": message_ref_or_inline(&output, schemas),
237                    },
238                },
239            },
240        });
241    } else if output.full_name() == "google.protobuf.Empty" {
242        op["responses"] = json!({
243            "200": {
244                "description": "Success (empty response)",
245            },
246        });
247    } else {
248        let schema_name = output.name().to_string();
249        let response_schema = message_to_schema(&output, &[], schemas);
250        schemas.insert(schema_name.clone(), response_schema);
251
252        op["responses"] = json!({
253            "200": {
254                "description": "Success",
255                "content": {
256                    "application/json": {
257                        "schema": {
258                            "$ref": format!("#/components/schemas/{}", schema_name),
259                        },
260                    },
261                },
262            },
263        });
264    }
265
266    // Common error responses.
267    if let Some(responses) = op.get_mut("responses").and_then(|r| r.as_object_mut()) {
268        responses.insert(
269            "400".to_string(),
270            json!({ "description": "Invalid argument" }),
271        );
272        responses.insert(
273            "401".to_string(),
274            json!({ "description": "Unauthenticated" }),
275        );
276        responses.insert(
277            "403".to_string(),
278            json!({ "description": "Permission denied" }),
279        );
280        responses.insert("404".to_string(), json!({ "description": "Not found" }));
281        responses.insert(
282            "503".to_string(),
283            json!({ "description": "Service unavailable" }),
284        );
285    }
286
287    op
288}
289
290/// Generate a JSON Schema for a protobuf message, excluding path parameter fields.
291fn message_to_schema(
292    msg: &MessageDescriptor,
293    exclude_fields: &[String],
294    schemas: &mut Map<String, Value>,
295) -> Value {
296    let mut properties = Map::new();
297    let required: Vec<String> = Vec::new();
298
299    for field in msg.fields() {
300        let name = field.name().to_string();
301        if exclude_fields.contains(&name) {
302            continue;
303        }
304
305        let schema = field_to_schema(&field);
306        properties.insert(name, schema);
307    }
308
309    let mut schema = json!({
310        "type": "object",
311        "properties": properties,
312    });
313
314    if !required.is_empty() {
315        schema["required"] = json!(required);
316    }
317
318    // Nested messages: register as separate schemas.
319    for field in msg.fields() {
320        if exclude_fields.contains(&field.name().to_string()) {
321            continue;
322        }
323        if let Kind::Message(nested) = field.kind() {
324            if !is_well_known(&nested) && !schemas.contains_key(nested.name()) {
325                let nested_schema = message_to_schema(&nested, &[], schemas);
326                schemas.insert(nested.name().to_string(), nested_schema);
327            }
328        }
329    }
330
331    schema
332}
333
334fn message_ref_or_inline(msg: &MessageDescriptor, schemas: &mut Map<String, Value>) -> Value {
335    let name = msg.name().to_string();
336    if !schemas.contains_key(&name) {
337        let schema = message_to_schema(msg, &[], schemas);
338        schemas.insert(name.clone(), schema);
339    }
340    json!({ "$ref": format!("#/components/schemas/{}", name) })
341}
342
343fn field_to_schema(field: &FieldDescriptor) -> Value {
344    let base = match field.kind() {
345        Kind::Double | Kind::Float => json!({ "type": "number", "format": "double" }),
346        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
347            json!({ "type": "integer", "format": "int32" })
348        }
349        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
350            json!({ "type": "string", "format": "int64", "description": "64-bit integer (string-encoded)" })
351        }
352        Kind::Uint32 | Kind::Fixed32 => {
353            json!({ "type": "integer", "format": "uint32" })
354        }
355        Kind::Uint64 | Kind::Fixed64 => {
356            json!({ "type": "string", "format": "uint64", "description": "64-bit unsigned integer (string-encoded)" })
357        }
358        Kind::Bool => json!({ "type": "boolean" }),
359        Kind::String => json!({ "type": "string" }),
360        Kind::Bytes => json!({ "type": "string", "format": "byte" }),
361        Kind::Enum(e) => {
362            let values: Vec<Value> = e.values().map(|v| json!(v.name())).collect();
363            json!({ "type": "string", "enum": values })
364        }
365        Kind::Message(msg) => {
366            if is_well_known(&msg) {
367                well_known_schema(&msg)
368            } else {
369                json!({ "$ref": format!("#/components/schemas/{}", msg.name()) })
370            }
371        }
372    };
373
374    if field.is_list() {
375        json!({ "type": "array", "items": base })
376    } else if field.is_map() {
377        // Map<K, V> → object with additionalProperties.
378        if let Kind::Message(entry) = field.kind() {
379            let value_field = entry.get_field_by_name("value");
380            let value_schema = value_field
381                .map(|f| field_to_schema(&f))
382                .unwrap_or_else(|| json!({}));
383            json!({ "type": "object", "additionalProperties": value_schema })
384        } else {
385            json!({ "type": "object" })
386        }
387    } else {
388        base
389    }
390}
391
392fn is_well_known(msg: &MessageDescriptor) -> bool {
393    msg.full_name().starts_with("google.protobuf.")
394}
395
396fn well_known_schema(msg: &MessageDescriptor) -> Value {
397    match msg.full_name() {
398        "google.protobuf.Timestamp" => {
399            json!({ "type": "string", "format": "date-time" })
400        }
401        "google.protobuf.Duration" => {
402            json!({ "type": "string", "format": "duration", "example": "3.5s" })
403        }
404        "google.protobuf.Empty" => json!({ "type": "object" }),
405        "google.protobuf.Struct" => json!({ "type": "object" }),
406        "google.protobuf.Value" => json!({}),
407        "google.protobuf.ListValue" => json!({ "type": "array", "items": {} }),
408        "google.protobuf.StringValue" | "google.protobuf.BytesValue" => {
409            json!({ "type": "string" })
410        }
411        "google.protobuf.BoolValue" => json!({ "type": "boolean" }),
412        "google.protobuf.Int32Value" | "google.protobuf.UInt32Value" => {
413            json!({ "type": "integer" })
414        }
415        "google.protobuf.Int64Value" | "google.protobuf.UInt64Value" => {
416            json!({ "type": "string", "format": "int64" })
417        }
418        "google.protobuf.FloatValue" | "google.protobuf.DoubleValue" => {
419            json!({ "type": "number" })
420        }
421        "google.protobuf.FieldMask" => {
422            json!({ "type": "string", "description": "Comma-separated field paths" })
423        }
424        "google.protobuf.Any" => {
425            json!({ "type": "object", "properties": { "@type": { "type": "string" } }, "additionalProperties": true })
426        }
427        _ => json!({ "type": "object" }),
428    }
429}
430
/// Extract the variable names from a `google.api.http` path template, in order.
///
/// `/v1/profiles/{profile_id}/devices/{device_id}` yields
/// `["profile_id", "device_id"]`. For the `{field=pattern}` form (e.g.
/// `/v1/{name=projects/*/items/*}`), only the field name (`name`) is returned,
/// so it can be matched against input message fields.
///
/// The following are ignored: empty `{}` variables, a `}` that does not close
/// a variable, and an unclosed trailing `{`.
fn extract_path_params(path: &str) -> Vec<String> {
    let mut params = Vec::new();
    // `Some(buf)` while inside `{...}`; a new `{` restarts the buffer.
    let mut current: Option<String> = None;

    for ch in path.chars() {
        match ch {
            '{' => current = Some(String::new()),
            '}' => {
                if let Some(inner) = current.take() {
                    // Drop the `=pattern` part of `{field=pattern}`.
                    let name = inner.split('=').next().unwrap_or_default();
                    if !name.is_empty() {
                        params.push(name.to_string());
                    }
                }
            }
            _ => {
                if let Some(buf) = current.as_mut() {
                    buf.push(ch);
                }
            }
        }
    }

    params
}
456
457/// Extract HTTP method and path from google.api.http annotation.
458fn extract_http_rule(method: &MethodDescriptor, pool: &DescriptorPool) -> Option<(String, String)> {
459    let http_ext = pool.get_extension_by_name("google.api.http")?;
460    let options = method.options();
461
462    if !options.has_extension(&http_ext) {
463        return None;
464    }
465
466    let http_rule = options.get_extension(&http_ext);
467    if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
468        for (method_name, _) in [
469            ("get", "get"),
470            ("post", "post"),
471            ("put", "put"),
472            ("delete", "delete"),
473            ("patch", "patch"),
474        ] {
475            if let Some(val) = rule_msg.get_field_by_name(method_name) {
476                if let prost_reflect::Value::String(path) = val.into_owned() {
477                    if !path.is_empty() {
478                        return Some((method_name.to_string(), path));
479                    }
480                }
481            }
482        }
483    }
484
485    None
486}
487
488/// Get proto source comments for a given fully-qualified name.
489fn get_comments(_full_name: &str, _pool: &DescriptorPool) -> Option<String> {
490    // prost-reflect doesn't expose source code info comments easily.
491    // For now, return None. Can be enhanced with protoc-gen-doc or
492    // manual SourceCodeInfo parsing.
493    None
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_path_params() {
        assert_eq!(
            extract_path_params("/v1/profiles/{profile_id}"),
            vec!["profile_id"]
        );
        assert_eq!(
            extract_path_params("/v1/profiles/{profile_id}/devices/{device_id}"),
            vec!["profile_id", "device_id"]
        );
        assert!(extract_path_params("/v1/auth/login").is_empty());
    }

    #[test]
    fn test_docs_html_contains_scalar() {
        let html = docs_html("/openapi.json", "Test API");
        assert!(html.contains("@scalar/api-reference"));
        assert!(html.contains("/openapi.json"));
        assert!(html.contains("Test API"));
    }

    #[test]
    fn test_well_known_schemas() {
        // Verify well-known type mappings are correct. If the global pool does
        // not contain `google.protobuf.Timestamp`, this test passes vacuously.
        let pool = DescriptorPool::global();
        if let Some(ts) = pool.get_message_by_name("google.protobuf.Timestamp") {
            let schema = well_known_schema(&ts);
            assert_eq!(schema["type"], "string");
            assert_eq!(schema["format"], "date-time");
        }
    }

    #[test]
    fn test_generate_empty_pool() {
        let pool = DescriptorPool::new();
        let config = OpenApiConfig {
            enabled: true,
            path: "/openapi.json".into(),
            docs_path: "/docs".into(),
            title: Some("Test API".into()),
            version: Some("0.1.0".into()),
        };
        let spec = generate(&pool, &config, &[]);

        assert_eq!(spec["openapi"], "3.0.3");
        assert_eq!(spec["info"]["title"], "Test API");
        assert_eq!(spec["info"]["version"], "0.1.0");
        assert!(spec["paths"].as_object().unwrap().is_empty());
        assert!(spec["tags"].as_array().unwrap().is_empty());
        // No schemas were collected, but the bearer scheme is always declared.
        assert!(spec["components"].get("schemas").is_none());
        assert_eq!(
            spec["components"]["securitySchemes"]["bearerAuth"]["scheme"],
            "bearer"
        );
    }

    #[test]
    fn test_add_path_operation_merges_methods() {
        // Replaces a former test that only asserted on literals it had just
        // built; this one exercises real code.
        let mut paths = Map::new();
        add_path_operation(&mut paths, "/v1/items", "get", json!({ "operationId": "Svc.List" }));
        add_path_operation(&mut paths, "/v1/items", "post", json!({ "operationId": "Svc.Create" }));
        add_path_operation(&mut paths, "/v1/other", "get", json!({ "operationId": "Svc.Other" }));

        assert_eq!(paths.len(), 2);
        assert_eq!(paths["/v1/items"]["get"]["operationId"], "Svc.List");
        assert_eq!(paths["/v1/items"]["post"]["operationId"], "Svc.Create");
        assert_eq!(paths["/v1/other"]["get"]["operationId"], "Svc.Other");

        // Same path + verb replaces the earlier operation.
        add_path_operation(&mut paths, "/v1/items", "get", json!({ "operationId": "Svc.List2" }));
        assert_eq!(paths["/v1/items"]["get"]["operationId"], "Svc.List2");
        assert_eq!(paths["/v1/items"]["post"]["operationId"], "Svc.Create");
    }
}