Skip to main content

spikard_codegen/sql/
openapi.rs

1//! Emit an OpenAPI 3.1 document from a slice of [`SqlRoute`].
2//!
3//! The spec is built as a raw `serde_json::Value` rather than reusing
4//! `crate::openapi::OpenApiSpec` because the existing struct is a subset that
5//! lacks several 3.1 idioms we need (array-typed `type`, `oneOf` for
6//! nullability, `enum`). Emitting as `Value` keeps this module decoupled and
7//! the output round-trips through any OpenAPI 3.1 consumer.
8
9use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11use serde_json::{Map, Value, json};
12
13use super::annotations::{ApiKeyLocation, AuthRequirement, HttpMethod, HttpParamBinding};
14use super::route::SqlRoute;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OpenApiInfo {
18    pub title: String,
19    pub version: String,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub description: Option<String>,
22}
23
24impl OpenApiInfo {
25    pub fn new(title: impl Into<String>, version: impl Into<String>) -> Self {
26        Self {
27            title: title.into(),
28            version: version.into(),
29            description: None,
30        }
31    }
32}
33
34/// Build an OpenAPI 3.1 document from a list of SQL-derived routes. The
35/// returned `Value` is ready to be `serde_json::to_writer_pretty`-ed to disk.
36pub fn openapi_from_routes(routes: &[SqlRoute], info: &OpenApiInfo) -> Value {
37    // Collect security schemes (one per distinct auth requirement) so we can
38    // reference them by name on per-operation `security` lists.
39    let (security_schemes, scheme_names) = collect_security_schemes(routes);
40
41    // Group operations by path so multiple methods on the same path share a
42    // `PathItem`. Using IndexMap keeps insertion order stable for snapshot
43    // testing.
44    let mut paths: IndexMap<String, Map<String, Value>> = IndexMap::new();
45    for route in routes {
46        let entry = paths.entry(route.http.path.clone()).or_default();
47        let operation = build_operation(route, &scheme_names);
48        entry.insert(method_key(route.http.method).to_string(), operation);
49    }
50
51    let mut paths_obj = Map::new();
52    for (path, methods) in paths {
53        paths_obj.insert(path, Value::Object(methods));
54    }
55
56    let mut spec = Map::new();
57    spec.insert("openapi".into(), json!("3.1.0"));
58    spec.insert("info".into(), serde_json::to_value(info).expect("info serializes"));
59    spec.insert("paths".into(), Value::Object(paths_obj));
60
61    let mut components = Map::new();
62    if !security_schemes.is_empty() {
63        components.insert("securitySchemes".into(), Value::Object(security_schemes));
64    }
65    if !components.is_empty() {
66        spec.insert("components".into(), Value::Object(components));
67    }
68
69    Value::Object(spec)
70}
71
72fn build_operation(route: &SqlRoute, scheme_names: &std::collections::BTreeMap<AuthRequirement, String>) -> Value {
73    let mut op = Map::new();
74    op.insert("operationId".into(), json!(&route.operation_id));
75
76    if let Some(s) = &route.http.summary {
77        op.insert("summary".into(), json!(s));
78    }
79    if let Some(d) = &route.http.description {
80        op.insert("description".into(), json!(d));
81    }
82    if !route.http.tags.is_empty() {
83        op.insert("tags".into(), json!(&route.http.tags));
84    }
85
86    let parameters = build_parameters(route);
87    if !parameters.is_empty() {
88        op.insert("parameters".into(), Value::Array(parameters));
89    }
90
91    if let Some(request_body) = build_request_body(route) {
92        op.insert("requestBody".into(), request_body);
93    }
94
95    op.insert("responses".into(), build_responses(route));
96
97    if let Some(auth) = &route.http.auth
98        && !matches!(auth, AuthRequirement::None)
99        && let Some(name) = scheme_names.get(auth)
100    {
101        op.insert("security".into(), json!([{ name.as_str(): [] }]));
102    }
103
104    Value::Object(op)
105}
106
107fn build_parameters(route: &SqlRoute) -> Vec<Value> {
108    let mut out = Vec::new();
109    let parameter_schema = &route.metadata["parameter_schema"];
110    let properties = parameter_schema.get("properties").and_then(Value::as_object);
111    let Some(properties) = properties else {
112        return out;
113    };
114    let required: std::collections::HashSet<&str> = parameter_schema
115        .get("required")
116        .and_then(Value::as_array)
117        .map(|arr| arr.iter().filter_map(Value::as_str).collect())
118        .unwrap_or_default();
119
120    for (name, schema) in properties {
121        let location = match route.param_locations.get(name) {
122            Some(HttpParamBinding::Path) => "path",
123            Some(HttpParamBinding::Query) => "query",
124            Some(HttpParamBinding::Header) => "header",
125            _ => continue,
126        };
127        let is_required = location == "path" || required.contains(name.as_str());
128        let mut p = Map::new();
129        p.insert("name".into(), json!(name));
130        p.insert("in".into(), json!(location));
131        p.insert("required".into(), json!(is_required));
132        p.insert("schema".into(), schema.clone());
133        out.push(Value::Object(p));
134    }
135    out
136}
137
138fn build_request_body(route: &SqlRoute) -> Option<Value> {
139    let request_schema = route.metadata.get("request_schema")?;
140    if request_schema.is_null() {
141        return None;
142    }
143    Some(json!({
144        "required": true,
145        "content": {
146            "application/json": { "schema": request_schema }
147        }
148    }))
149}
150
151fn build_responses(route: &SqlRoute) -> Value {
152    let mut responses = Map::new();
153    let response_schema = route.metadata.get("response_schema").cloned().unwrap_or(Value::Null);
154    let codes: Vec<u16> = if route.http.status_codes.is_empty() {
155        vec![route.default_status]
156    } else {
157        route.http.status_codes.clone()
158    };
159    for (idx, code) in codes.iter().enumerate() {
160        let is_primary = idx == 0;
161        let mut body = Map::new();
162        body.insert("description".into(), json!(describe_status(*code)));
163        if is_primary && !response_schema.is_null() && *code != 204 {
164            body.insert(
165                "content".into(),
166                json!({ "application/json": { "schema": response_schema.clone() } }),
167            );
168        }
169        responses.insert(code.to_string(), Value::Object(body));
170    }
171    Value::Object(responses)
172}
173
/// Return the standard reason phrase for a common HTTP status code.
///
/// Used as the mandatory `description` of each OpenAPI response. Codes not in
/// the table fall back to the generic `"Response"`.
const fn describe_status(code: u16) -> &'static str {
    const REASON_PHRASES: [(u16, &str); 11] = [
        (200, "OK"),
        (201, "Created"),
        (202, "Accepted"),
        (204, "No Content"),
        (400, "Bad Request"),
        (401, "Unauthorized"),
        (403, "Forbidden"),
        (404, "Not Found"),
        (409, "Conflict"),
        (422, "Unprocessable Entity"),
        (500, "Internal Server Error"),
    ];

    // Linear scan: iterators are not usable in a `const fn`.
    let mut i = 0;
    while i < REASON_PHRASES.len() {
        let (known, phrase) = REASON_PHRASES[i];
        if known == code {
            return phrase;
        }
        i += 1;
    }
    "Response"
}
190
191fn collect_security_schemes(
192    routes: &[SqlRoute],
193) -> (Map<String, Value>, std::collections::BTreeMap<AuthRequirement, String>) {
194    let mut schemes = Map::new();
195    let mut name_for = std::collections::BTreeMap::new();
196    for route in routes {
197        let Some(auth) = &route.http.auth else { continue };
198        if matches!(auth, AuthRequirement::None) {
199            continue;
200        }
201        if name_for.contains_key(auth) {
202            continue;
203        }
204        let name = match auth {
205            AuthRequirement::None => unreachable!(),
206            AuthRequirement::Bearer { format: None } => "bearerAuth".to_string(),
207            AuthRequirement::Bearer { format: Some(f) } => format!("bearer{}", f.to_uppercase()),
208            AuthRequirement::ApiKey { location, name } => {
209                format!("apiKey_{}_{}", location_short(*location), name.replace('-', "_"))
210            }
211        };
212        let scheme_value = match auth {
213            AuthRequirement::None => unreachable!(),
214            AuthRequirement::Bearer { format } => {
215                let mut s = Map::new();
216                s.insert("type".into(), json!("http"));
217                s.insert("scheme".into(), json!("bearer"));
218                if let Some(f) = format {
219                    s.insert("bearerFormat".into(), json!(f));
220                }
221                Value::Object(s)
222            }
223            AuthRequirement::ApiKey { location, name } => json!({
224                "type": "apiKey",
225                "in": location_str(*location),
226                "name": name,
227            }),
228        };
229        schemes.insert(name.clone(), scheme_value);
230        name_for.insert(auth.clone(), name);
231    }
232    (schemes, name_for)
233}
234
235const fn location_short(loc: ApiKeyLocation) -> &'static str {
236    match loc {
237        ApiKeyLocation::Header => "h",
238        ApiKeyLocation::Query => "q",
239        ApiKeyLocation::Cookie => "c",
240    }
241}
242
243const fn location_str(loc: ApiKeyLocation) -> &'static str {
244    match loc {
245        ApiKeyLocation::Header => "header",
246        ApiKeyLocation::Query => "query",
247        ApiKeyLocation::Cookie => "cookie",
248    }
249}
250
251const fn method_key(m: HttpMethod) -> &'static str {
252    match m {
253        HttpMethod::Get => "get",
254        HttpMethod::Post => "post",
255        HttpMethod::Put => "put",
256        HttpMethod::Patch => "patch",
257        HttpMethod::Delete => "delete",
258        HttpMethod::Head => "head",
259        HttpMethod::Options => "options",
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::neutral_to_json_schema::BuildOptions;
    use crate::sql::route::route_from_query;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::catalog::Catalog;
    use scythe_core::parser::{CustomAnnotation, QueryCommand};

    /// Catalog built from no DDL statements; the fixtures below are converted against it.
    fn empty_catalog() -> Catalog {
        Catalog::from_ddl(&[]).unwrap()
    }

    /// `GetUser`: `GET /users/{id}` returning one row (`id`, `email`), with
    /// `bearer:jwt` auth, status codes `200,404`, tag `users` and a summary.
    fn get_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "GetUser".to_string(),
            command: QueryCommand::One,
            sql: "SELECT id, email FROM users WHERE id = $1".to_string(),
            columns: vec![
                AnalyzedColumn {
                    name: "id".into(),
                    neutral_type: "int64".into(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".into(),
                    neutral_type: "string".into(),
                    nullable: false,
                },
            ],
            params: vec![AnalyzedParam {
                name: "id".into(),
                neutral_type: "int64".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: Some("users".into()),
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "GET /users/{id}".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "200,404".into(),
                    line: 3,
                },
                CustomAnnotation {
                    name: "http_tags".into(),
                    value: "users".into(),
                    line: 4,
                },
                CustomAnnotation {
                    name: "http_summary".into(),
                    value: "Fetch a user".into(),
                    line: 5,
                },
            ],
        }
    }

    /// `CreateUser`: `POST /users` taking an `email` param (`:exec_rows`),
    /// with `bearer:jwt` auth and status code `201`.
    fn create_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "CreateUser".to_string(),
            command: QueryCommand::ExecRows,
            sql: "INSERT INTO users (email) VALUES ($1)".to_string(),
            columns: vec![],
            params: vec![AnalyzedParam {
                name: "email".into(),
                neutral_type: "string".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: None,
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "POST /users".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "201".into(),
                    line: 3,
                },
            ],
        }
    }

    /// Convert both fixtures into routes (`GetUser` first); both `unwrap`s
    /// must succeed for each fixture.
    fn build_two_routes() -> Vec<SqlRoute> {
        let opts = BuildOptions::default();
        let r1 = route_from_query(&get_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        let r2 = route_from_query(&create_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        vec![r1, r2]
    }

    #[test]
    fn emits_openapi_3_1_header() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("test", "0.1.0"));
        assert_eq!(spec["openapi"], "3.1.0");
        assert_eq!(spec["info"]["title"], "test");
        assert_eq!(spec["info"]["version"], "0.1.0");
    }

    #[test]
    fn groups_methods_under_shared_path() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        // /users has POST; /users/{id} has GET.
        assert!(spec["paths"]["/users"]["post"].is_object());
        assert!(spec["paths"]["/users/{id}"]["get"].is_object());
    }

    #[test]
    fn operation_carries_operation_id_summary_tags() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        assert_eq!(op["operationId"], "GetUser");
        assert_eq!(op["summary"], "Fetch a user");
        assert_eq!(op["tags"], json!(["users"]));
    }

    #[test]
    fn path_parameter_emitted() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let params = spec["paths"]["/users/{id}"]["get"]["parameters"].as_array().unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0]["name"], "id");
        assert_eq!(params[0]["in"], "path");
        assert_eq!(params[0]["required"], true);
    }

    #[test]
    fn post_carries_request_body() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let body = &spec["paths"]["/users"]["post"]["requestBody"];
        assert_eq!(body["required"], true);
        assert!(body["content"]["application/json"]["schema"]["properties"]["email"].is_object());
    }

    #[test]
    fn responses_keyed_by_status_codes() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let resp = &spec["paths"]["/users/{id}"]["get"]["responses"];
        assert!(resp["200"].is_object());
        assert!(resp["404"].is_object());
    }

    #[test]
    fn primary_response_includes_schema() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let primary = &spec["paths"]["/users/{id}"]["get"]["responses"]["200"];
        assert!(primary["content"]["application/json"]["schema"]["properties"]["id"].is_object());
    }

    #[test]
    fn registers_bearer_security_scheme_once() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let schemes = &spec["components"]["securitySchemes"];
        // Both routes share `bearer:jwt`, so exactly one scheme is registered.
        assert_eq!(schemes.as_object().unwrap().len(), 1);
        let (_name, scheme) = schemes.as_object().unwrap().iter().next().unwrap();
        assert_eq!(scheme["type"], "http");
        assert_eq!(scheme["scheme"], "bearer");
        assert_eq!(scheme["bearerFormat"], "jwt");
    }

    #[test]
    fn operations_reference_security_scheme() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        let sec = op["security"].as_array().unwrap();
        assert_eq!(sec.len(), 1);
        let scheme_name = sec[0].as_object().unwrap().keys().next().unwrap();
        // The name must exist in components.securitySchemes.
        assert!(spec["components"]["securitySchemes"][scheme_name].is_object());
    }

    #[test]
    fn no_204_response_carries_body() {
        let mut q = create_user_query();
        // Use :exec instead of :exec_rows to get the 204-default path.
        q.command = QueryCommand::Exec;
        // Drop the @http_status annotation so no explicit codes are declared.
        q.custom.retain(|a| a.name != "http_status");
        let route = route_from_query(&q, &empty_catalog(), &BuildOptions::default())
            .unwrap()
            .unwrap();
        let spec = openapi_from_routes(&[route], &OpenApiInfo::new("t", "1"));
        let responses = &spec["paths"]["/users"]["post"]["responses"];
        // Guard against a vacuous pass: indexing a missing key yields `Null`,
        // and `Null["content"]` is also `Null`. Prove the 204 entry exists first.
        assert!(
            responses["204"].is_object(),
            "expected a 204 response, got {responses}"
        );
        assert!(responses["204"]["content"].is_null());
    }
}
488}