Skip to main content

spikard_cli/codegen/
typescript.rs

1//! TypeScript code generation from `OpenAPI` schemas
2
3use super::NodeDtoStyle;
4use anyhow::Result;
5use heck::{ToPascalCase, ToSnakeCase};
6use openapiv3::{
7    OpenAPI, Operation, Parameter, ParameterSchemaOrContent, ReferenceOr, Schema, SchemaKind, StringFormat, Type,
8    VariantOrUnknownOrEmpty,
9};
10use std::collections::{HashMap, HashSet, VecDeque};
11
/// Generates a TypeScript module (header, Zod schemas, route handler stubs and a
/// footer) from a parsed `OpenAPI` document.
pub struct TypeScriptGenerator {
    /// Parsed `OpenAPI` specification the generated code is derived from.
    spec: OpenAPI,
    /// DTO style; selects which header/import preamble is emitted.
    dto: NodeDtoStyle,
}
16
17impl TypeScriptGenerator {
    /// Creates a generator that renders `spec` using the given `dto` style.
    #[must_use]
    pub const fn new(spec: OpenAPI, dto: NodeDtoStyle) -> Self {
        Self { spec, dto }
    }
22
23    pub fn generate(&self) -> Result<String> {
24        let mut output = String::new();
25
26        output.push_str(&self.generate_header());
27
28        output.push_str(&self.generate_schemas()?);
29
30        output.push_str(&self.generate_routes()?);
31
32        output.push_str(&self.generate_main());
33
34        Ok(output)
35    }
36
37    fn generate_header(&self) -> String {
38        match self.dto {
39            NodeDtoStyle::Zod => format!(
40                r#"// Generated by Spikard OpenAPI code generator
41// OpenAPI Version: {}
42// Title: {}
43// DO NOT EDIT - regenerate from OpenAPI schema
44
45import {{ route }} from "spikard";
46import type {{ Body, Path, Query, Request }} from "spikard";
47import {{ z }} from "zod";
48
49"#,
50                self.spec.openapi, self.spec.info.title
51            ),
52        }
53    }
54
55    fn generate_schemas(&self) -> Result<String> {
56        let mut output = String::new();
57        output.push_str("// Zod Schemas\n\n");
58
59        if let Some(components) = &self.spec.components {
60            // Convert ReferenceOr::Item into a plain HashMap for sorting
61            let mut schemas_map: HashMap<String, Schema> = HashMap::new();
62            let mut schema_refs: HashMap<String, ReferenceOr<Schema>> = HashMap::new();
63
64            for (name, schema_ref) in &components.schemas {
65                schema_refs.insert(name.clone(), schema_ref.clone());
66                if let ReferenceOr::Item(schema) = schema_ref {
67                    schemas_map.insert(name.clone(), schema.clone());
68                }
69            }
70
71            // Topologically sort the schemas
72            let sorted_names = topological_sort_schemas(&schemas_map);
73
74            // Generate schemas in topologically sorted order
75            for name in sorted_names {
76                if let Some(ReferenceOr::Item(schema)) = schema_refs.get(&name) {
77                    output.push_str(&self.generate_zod_schema(&name, schema)?);
78                    output.push('\n');
79                }
80            }
81        }
82
83        Ok(output)
84    }
85
86    fn generate_zod_schema(&self, name: &str, schema: &Schema) -> Result<String> {
87        let schema_name = format!("{}Schema", name.to_pascal_case());
88        let type_name = name.to_pascal_case();
89        let mut output = String::new();
90
91        if let Some(description) = &schema.schema_data.description {
92            output.push_str(&format!("/** {description} */\n"));
93        }
94
95        let mut schema_expr = match &schema.schema_kind {
96            SchemaKind::Type(Type::Object(obj)) => {
97                let mut expr = String::from("z.object({\n");
98
99                for (prop_name, prop_schema_ref) in &obj.properties {
100                    let is_required = obj.required.contains(prop_name);
101                    let field_name = prop_name.to_snake_case();
102
103                    let zod_type = match prop_schema_ref {
104                        ReferenceOr::Item(prop_schema) => Self::schema_to_zod_type(prop_schema, !is_required),
105                        ReferenceOr::Reference { reference } => {
106                            let ref_name = reference.split('/').next_back().unwrap();
107                            let ref_schema = format!("{}Schema", ref_name.to_pascal_case());
108                            if is_required {
109                                ref_schema
110                            } else {
111                                format!("{ref_schema}.optional()")
112                            }
113                        }
114                    };
115
116                    expr.push_str(&format!("\t{field_name}: {zod_type},\n"));
117                }
118
119                expr.push_str("})");
120                expr
121            }
122            _ => "z.unknown()".to_string(),
123        };
124
125        if schema.schema_data.nullable {
126            schema_expr.push_str(".nullable()");
127        }
128
129        output.push_str(&format!("export const {schema_name} = {schema_expr};\n"));
130
131        output.push_str(&format!("\nexport type {type_name} = z.infer<typeof {schema_name}>;\n"));
132
133        Ok(output)
134    }
135
136    /// Extract type name from a schema reference or inline schema
137    fn extract_type_from_schema_ref(&self, schema_ref: &ReferenceOr<Schema>) -> String {
138        match schema_ref {
139            ReferenceOr::Reference { reference } => {
140                let ref_name = reference.split('/').next_back().unwrap();
141                ref_name.to_pascal_case()
142            }
143            ReferenceOr::Item(schema) => Self::schema_to_typescript_type(schema, false),
144        }
145    }
146
147    /// Extract TypeScript type from a schema reference or inline schema
148    fn extract_typescript_type_from_ref(&self, schema_ref: &ReferenceOr<Schema>) -> String {
149        match schema_ref {
150            ReferenceOr::Reference { reference } => {
151                let ref_name = reference.split('/').next_back().unwrap();
152                ref_name.to_pascal_case()
153            }
154            ReferenceOr::Item(schema) => Self::schema_to_typescript_type(schema, false),
155        }
156    }
157
158    /// Extract request body TypeScript type from operation
159    fn extract_request_body_type(&self, operation: &Operation) -> Option<String> {
160        operation.request_body.as_ref().and_then(|body_ref| match body_ref {
161            ReferenceOr::Item(request_body) => request_body.content.get("application/json").and_then(|media_type| {
162                media_type
163                    .schema
164                    .as_ref()
165                    .map(|schema_ref| self.extract_typescript_type_from_ref(schema_ref))
166            }),
167            ReferenceOr::Reference { reference } => {
168                let ref_name = reference.split('/').next_back().unwrap();
169                Some(ref_name.to_pascal_case())
170            }
171        })
172    }
173
174    /// Extract response type from operation (looks for 200/201 responses)
175    fn extract_response_type(&self, operation: &Operation) -> String {
176        use openapiv3::StatusCode;
177
178        let response = operation
179            .responses
180            .responses
181            .get(&StatusCode::Code(200))
182            .or_else(|| operation.responses.responses.get(&StatusCode::Code(201)))
183            .or_else(|| operation.responses.responses.get(&StatusCode::Range(2)));
184
185        if let Some(response_ref) = response {
186            match response_ref {
187                ReferenceOr::Item(response) => {
188                    if let Some(content) = response.content.get("application/json")
189                        && let Some(schema_ref) = &content.schema
190                    {
191                        return self.extract_type_from_schema_ref(schema_ref);
192                    }
193                }
194                ReferenceOr::Reference { reference } => {
195                    let ref_name = reference.split('/').next_back().unwrap();
196                    return ref_name.to_pascal_case();
197                }
198            }
199        }
200
201        "Record<string, unknown>".to_string()
202    }
203
204    fn schema_to_zod_type(schema: &Schema, optional: bool) -> String {
205        let mut base_type = match &schema.schema_kind {
206            SchemaKind::Type(Type::String(string_type)) => {
207                let enum_values = string_type
208                    .enumeration
209                    .iter()
210                    .flatten()
211                    .map(|value| format!("{value:?}"))
212                    .collect::<Vec<_>>();
213                if !enum_values.is_empty() {
214                    format!("z.enum([{}])", enum_values.join(", "))
215                } else {
216                    match &string_type.format {
217                        VariantOrUnknownOrEmpty::Item(StringFormat::Date) => "z.string()".to_string(),
218                        VariantOrUnknownOrEmpty::Item(StringFormat::DateTime) => "z.string().datetime()".to_string(),
219                        VariantOrUnknownOrEmpty::Unknown(format) if format == "email" => {
220                            "z.string().email()".to_string()
221                        }
222                        VariantOrUnknownOrEmpty::Unknown(format) if format == "uuid" => "z.string().uuid()".to_string(),
223                        _ => "z.string()".to_string(),
224                    }
225                }
226            }
227            SchemaKind::Type(Type::Number(_)) => "z.number()".to_string(),
228            SchemaKind::Type(Type::Integer(_)) => "z.number().int()".to_string(),
229            SchemaKind::Type(Type::Boolean(_)) => "z.boolean()".to_string(),
230            SchemaKind::Type(Type::Array(arr)) => {
231                let item_type = match &arr.items {
232                    Some(ReferenceOr::Item(item_schema)) => Self::schema_to_zod_type(item_schema, false),
233                    Some(ReferenceOr::Reference { reference }) => {
234                        let ref_name = reference.split('/').next_back().unwrap();
235                        format!("{}Schema", ref_name.to_pascal_case())
236                    }
237                    None => "z.record(z.string(), z.unknown())".to_string(),
238                };
239                format!("z.array({item_type})")
240            }
241            SchemaKind::Type(Type::Object(obj)) => {
242                if obj.properties.is_empty() {
243                    "z.record(z.string(), z.unknown())".to_string()
244                } else {
245                    let mut fields = String::from("z.object({\n");
246                    for (prop_name, prop_schema_ref) in &obj.properties {
247                        let is_required = obj.required.contains(prop_name);
248                        let field_name = prop_name.to_snake_case();
249                        let zod_type = match prop_schema_ref {
250                            ReferenceOr::Item(prop_schema) => Self::schema_to_zod_type(prop_schema, !is_required),
251                            ReferenceOr::Reference { reference } => {
252                                let ref_name = reference.split('/').next_back().unwrap();
253                                let ref_schema = format!("{}Schema", ref_name.to_pascal_case());
254                                if is_required {
255                                    ref_schema
256                                } else {
257                                    format!("{ref_schema}.optional()")
258                                }
259                            }
260                        };
261                        fields.push_str(&format!("\t{field_name}: {zod_type},\n"));
262                    }
263                    fields.push_str("})");
264                    fields
265                }
266            }
267            SchemaKind::OneOf { one_of } => {
268                let members = one_of
269                    .iter()
270                    .map(|schema_ref| match schema_ref {
271                        ReferenceOr::Item(item_schema) => Self::schema_to_zod_type(item_schema, false),
272                        ReferenceOr::Reference { reference } => {
273                            let ref_name = reference.split('/').next_back().unwrap();
274                            format!("{}Schema", ref_name.to_pascal_case())
275                        }
276                    })
277                    .collect::<Vec<_>>();
278                format!("z.union([{}])", members.join(", "))
279            }
280            SchemaKind::AnyOf { any_of } => {
281                let members = any_of
282                    .iter()
283                    .map(|schema_ref| match schema_ref {
284                        ReferenceOr::Item(item_schema) => Self::schema_to_zod_type(item_schema, false),
285                        ReferenceOr::Reference { reference } => {
286                            let ref_name = reference.split('/').next_back().unwrap();
287                            format!("{}Schema", ref_name.to_pascal_case())
288                        }
289                    })
290                    .collect::<Vec<_>>();
291                format!("z.union([{}])", members.join(", "))
292            }
293            SchemaKind::AllOf { all_of } => {
294                let members = all_of
295                    .iter()
296                    .map(|schema_ref| match schema_ref {
297                        ReferenceOr::Item(item_schema) => Self::schema_to_zod_type(item_schema, false),
298                        ReferenceOr::Reference { reference } => {
299                            let ref_name = reference.split('/').next_back().unwrap();
300                            format!("{}Schema", ref_name.to_pascal_case())
301                        }
302                    })
303                    .collect::<Vec<_>>();
304                match members.split_first() {
305                    Some((first, rest)) => rest
306                        .iter()
307                        .fold(first.clone(), |acc, member| format!("{acc}.and({member})")),
308                    None => "z.unknown()".to_string(),
309                }
310            }
311            _ => "z.unknown()".to_string(),
312        };
313
314        if schema.schema_data.nullable {
315            base_type.push_str(".nullable()");
316        }
317
318        if optional {
319            base_type.push_str(".optional()");
320        }
321
322        base_type
323    }
324
325    fn schema_to_typescript_type(schema: &Schema, optional: bool) -> String {
326        let mut base_type = match &schema.schema_kind {
327            SchemaKind::Type(Type::String(string_type)) => {
328                let enum_values = string_type
329                    .enumeration
330                    .iter()
331                    .flatten()
332                    .map(|value| format!("{value:?}"))
333                    .collect::<Vec<_>>();
334                if enum_values.is_empty() {
335                    "string".to_string()
336                } else {
337                    enum_values.join(" | ")
338                }
339            }
340            SchemaKind::Type(Type::Number(_) | Type::Integer(_)) => "number".to_string(),
341            SchemaKind::Type(Type::Boolean(_)) => "boolean".to_string(),
342            SchemaKind::Type(Type::Array(arr)) => {
343                let item_type = match &arr.items {
344                    Some(ReferenceOr::Item(item_schema)) => Self::schema_to_typescript_type(item_schema, false),
345                    Some(ReferenceOr::Reference { reference }) => {
346                        let ref_name = reference.split('/').next_back().unwrap();
347                        ref_name.to_pascal_case()
348                    }
349                    None => "Record<string, unknown>".to_string(),
350                };
351                let item_type = if item_type.contains(" | ") {
352                    format!("({item_type})")
353                } else {
354                    item_type
355                };
356                format!("{item_type}[]")
357            }
358            SchemaKind::Type(Type::Object(obj)) => {
359                if obj.properties.is_empty() {
360                    "Record<string, unknown>".to_string()
361                } else {
362                    let fields = obj
363                        .properties
364                        .iter()
365                        .map(|(prop_name, prop_schema_ref)| {
366                            let optional_marker = if obj.required.contains(prop_name) { "" } else { "?" };
367                            let prop_type = match prop_schema_ref {
368                                ReferenceOr::Item(prop_schema) => Self::schema_to_typescript_type(prop_schema, false),
369                                ReferenceOr::Reference { reference } => {
370                                    let ref_name = reference.split('/').next_back().unwrap();
371                                    ref_name.to_pascal_case()
372                                }
373                            };
374                            format!("{prop_name}{optional_marker}: {prop_type}")
375                        })
376                        .collect::<Vec<_>>()
377                        .join("; ");
378                    format!("{{ {fields} }}")
379                }
380            }
381            SchemaKind::OneOf { one_of } | SchemaKind::AnyOf { any_of: one_of } => one_of
382                .iter()
383                .map(|schema_ref| match schema_ref {
384                    ReferenceOr::Item(item_schema) => Self::schema_to_typescript_type(item_schema, false),
385                    ReferenceOr::Reference { reference } => {
386                        let ref_name = reference.split('/').next_back().unwrap();
387                        ref_name.to_pascal_case()
388                    }
389                })
390                .collect::<Vec<_>>()
391                .join(" | "),
392            SchemaKind::AllOf { all_of } => all_of
393                .iter()
394                .map(|schema_ref| match schema_ref {
395                    ReferenceOr::Item(item_schema) => Self::schema_to_typescript_type(item_schema, false),
396                    ReferenceOr::Reference { reference } => {
397                        let ref_name = reference.split('/').next_back().unwrap();
398                        ref_name.to_pascal_case()
399                    }
400                })
401                .collect::<Vec<_>>()
402                .join(" & "),
403            _ => "unknown".to_string(),
404        };
405
406        if schema.schema_data.nullable {
407            base_type.push_str(" | null");
408        }
409
410        if optional {
411            base_type.push_str(" | undefined");
412        }
413
414        base_type
415    }
416
417    fn generate_routes(&self) -> Result<String> {
418        let mut output = String::new();
419        output.push_str("\n// Route Handlers\n\n");
420
421        for (path, path_item_ref) in &self.spec.paths.paths {
422            let path_item = match path_item_ref {
423                ReferenceOr::Item(item) => item,
424                ReferenceOr::Reference { .. } => continue,
425            };
426
427            if let Some(op) = &path_item.get {
428                output.push_str(&self.generate_route_handler(path, "get", op)?);
429            }
430            if let Some(op) = &path_item.post {
431                output.push_str(&self.generate_route_handler(path, "post", op)?);
432            }
433            if let Some(op) = &path_item.put {
434                output.push_str(&self.generate_route_handler(path, "put", op)?);
435            }
436            if let Some(op) = &path_item.delete {
437                output.push_str(&self.generate_route_handler(path, "delete", op)?);
438            }
439            if let Some(op) = &path_item.patch {
440                output.push_str(&self.generate_route_handler(path, "patch", op)?);
441            }
442        }
443
444        Ok(output)
445    }
446
447    fn generate_route_handler(&self, path: &str, method: &str, operation: &Operation) -> Result<String> {
448        let mut output = String::new();
449
450        if let Some(summary) = &operation.summary {
451            output.push_str(&format!("/**\n * {summary}\n"));
452        } else {
453            output.push_str("/**\n");
454        }
455        output.push_str(&format!(" * Route: {} {}\n", method.to_uppercase(), path));
456        output.push_str(" */\n");
457
458        let func_name = operation
459            .operation_id
460            .as_ref()
461            .map(|id| id.to_snake_case())
462            .unwrap_or_else(|| {
463                format!(
464                    "{}_{}",
465                    method,
466                    path.replace('/', "_").replace(['{', '}'], "").trim_matches('_')
467                )
468            });
469
470        let mut path_params = Vec::new();
471        let mut query_params = Vec::new();
472
473        for param_ref in &operation.parameters {
474            if let ReferenceOr::Item(param) = param_ref {
475                match param {
476                    Parameter::Path { parameter_data, .. } => {
477                        let type_hint = Self::parameter_typescript_type(parameter_data);
478                        path_params.push((parameter_data.name.clone(), type_hint));
479                    }
480                    Parameter::Query { parameter_data, .. } => {
481                        let type_hint = Self::parameter_typescript_type(parameter_data);
482                        query_params.push((parameter_data.name.clone(), type_hint, parameter_data.required));
483                    }
484                    _ => {}
485                }
486            }
487        }
488
489        let body_type = self.extract_request_body_type(operation);
490
491        let return_type = self.extract_response_type(operation);
492
493        output.push_str(&format!("export function {func_name}(_request: Request"));
494
495        for (param_name, param_type) in &path_params {
496            output.push_str(&format!(", _{}: Path<{}>", param_name.to_snake_case(), param_type));
497        }
498
499        for (param_name, param_type, required) in &query_params {
500            if *required {
501                output.push_str(&format!(", _{}: Query<{}>", param_name.to_snake_case(), param_type));
502            } else {
503                output.push_str(&format!(
504                    ", _{}: Query<{} | undefined>",
505                    param_name.to_snake_case(),
506                    param_type
507                ));
508            }
509        }
510
511        if let Some(body_type) = &body_type {
512            output.push_str(&format!(", _body: Body<{body_type}>"));
513        }
514
515        output.push_str(&format!("): {return_type} {{\n"));
516
517        if let Some(desc) = &operation.description {
518            output.push_str(&format!("\t/**\n\t * {desc}\n\t */\n"));
519        }
520
521        output.push_str("\tthrow new Error(\"TODO: Implement this endpoint\");\n");
522        output.push_str("}\n");
523
524        output.push_str(&format!(
525            "route(\"{}\", {{ methods: [\"{}\"] }})({});\n\n",
526            path,
527            method.to_uppercase(),
528            func_name
529        ));
530
531        Ok(output)
532    }
533
534    fn generate_main(&self) -> String {
535        r"
536// Run the application
537// Note: Actual server setup depends on your runtime configuration
538"
539        .to_string()
540    }
541
542    fn parameter_typescript_type(parameter_data: &openapiv3::ParameterData) -> String {
543        match &parameter_data.format {
544            ParameterSchemaOrContent::Schema(schema_ref) => match schema_ref {
545                ReferenceOr::Item(schema) => Self::schema_to_typescript_type(schema, false),
546                ReferenceOr::Reference { reference } => {
547                    let ref_name = reference.split('/').next_back().unwrap();
548                    ref_name.to_pascal_case()
549                }
550            },
551            ParameterSchemaOrContent::Content(_) => "unknown".to_string(),
552        }
553    }
554}
555
556/// Extract all schema names that are referenced in a given schema
557fn extract_schema_dependencies(schema: &Schema) -> HashSet<String> {
558    let mut dependencies = HashSet::new();
559    extract_dependencies_recursive(schema, &mut dependencies);
560    dependencies
561}
562
563/// Recursively extract all $ref schema names from a schema
564fn extract_dependencies_recursive(schema: &Schema, deps: &mut HashSet<String>) {
565    match &schema.schema_kind {
566        SchemaKind::Type(Type::Object(obj)) => {
567            // Extract dependencies from properties
568            for (_prop_name, prop_schema_ref) in &obj.properties {
569                match prop_schema_ref {
570                    ReferenceOr::Reference { reference } => {
571                        if let Some(ref_name) = reference.split('/').next_back() {
572                            deps.insert(ref_name.to_string());
573                        }
574                    }
575                    ReferenceOr::Item(prop_schema) => {
576                        extract_dependencies_recursive(prop_schema, deps);
577                    }
578                }
579            }
580        }
581        SchemaKind::Type(Type::Array(arr)) => {
582            // Extract dependencies from array items
583            if let Some(items) = &arr.items {
584                match items {
585                    ReferenceOr::Reference { reference } => {
586                        if let Some(ref_name) = reference.split('/').next_back() {
587                            deps.insert(ref_name.to_string());
588                        }
589                    }
590                    ReferenceOr::Item(item_schema) => {
591                        extract_dependencies_recursive(item_schema, deps);
592                    }
593                }
594            }
595        }
596        _ => {}
597    }
598}
599
600/// Topologically sort schemas by their dependencies using Kahn's algorithm
601fn topological_sort_schemas(schemas: &HashMap<String, Schema>) -> Vec<String> {
602    let mut in_degree: HashMap<String, usize> = HashMap::new();
603    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
604
605    // Initialize all schemas with zero in-degree
606    for schema_name in schemas.keys() {
607        in_degree.insert(schema_name.clone(), 0);
608        graph.insert(schema_name.clone(), Vec::new());
609    }
610
611    // Build dependency graph
612    for (schema_name, schema) in schemas {
613        let deps = extract_schema_dependencies(schema);
614        for dep in deps {
615            // Only track dependencies that are defined in the schema set
616            if schemas.contains_key(&dep) {
617                // Add edge from dependency to current schema
618                graph.entry(dep).or_default().push(schema_name.clone());
619                // Increment in-degree for current schema
620                *in_degree.get_mut(schema_name).unwrap() += 1;
621            }
622        }
623    }
624
625    // Find all schemas with no incoming edges
626    let mut queue: VecDeque<String> = in_degree
627        .iter()
628        .filter(|(_, deg)| **deg == 0)
629        .map(|(name, _)| name.clone())
630        .collect();
631
632    let mut result = Vec::new();
633
634    // Process schemas in topological order
635    while let Some(node) = queue.pop_front() {
636        result.push(node.clone());
637
638        // For each neighbor of the current node
639        if let Some(neighbors) = graph.get(&node) {
640            for neighbor in neighbors {
641                let deg = in_degree.get_mut(neighbor).unwrap();
642                *deg -= 1;
643                if *deg == 0 {
644                    queue.push_back(neighbor.clone());
645                }
646            }
647        }
648    }
649
650    result
651}