torn_api_codegen/model/object.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use indexmap::{map::Entry, IndexMap};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::Ident;
6
7use crate::openapi::r#type::OpenApiType;
8
9use super::{r#enum::Enum, ResolvedSchema, WarningReporter};
10
/// Scalar OpenAPI types that map directly onto a built-in (or `chrono`) Rust type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PrimitiveType {
    /// `boolean` without a format; generated as `bool`.
    Bool,
    /// `integer` with format `int32`; generated as `i32`.
    I32,
    /// `integer` with format `int64`; generated as `i64`.
    I64,
    /// `string` without a format; generated as `String`.
    String,
    /// `number` (any format), or any type with format `float`; generated as `f64`.
    Float,
    /// `string` with format `date`; generated as `chrono::DateTime<chrono::Utc>`.
    DateTime,
}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum PropertyType {
23    Primitive(PrimitiveType),
24    Ref(String),
25    Enum(Enum),
26    Nested(Box<Object>),
27    Array(Box<PropertyType>),
28    Any,
29}
30
31impl PropertyType {
32    pub fn codegen(
33        &self,
34        namespace: &mut ObjectNamespace,
35        resolved: &ResolvedSchema,
36    ) -> Option<TokenStream> {
37        match self {
38            Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()),
39            Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()),
40            Self::Primitive(PrimitiveType::I64) => Some(format_ident!("i64").into_token_stream()),
41            Self::Primitive(PrimitiveType::String) => {
42                Some(format_ident!("String").into_token_stream())
43            }
44            Self::Primitive(PrimitiveType::DateTime) => {
45                Some(quote! { chrono::DateTime<chrono::Utc> })
46            }
47            Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()),
48            Self::Ref(path) => {
49                let name = path.strip_prefix("#/components/schemas/")?;
50                let name = format_ident!("{name}");
51
52                Some(quote! { crate::models::#name })
53            }
54            Self::Enum(r#enum) => {
55                let code = r#enum.codegen(resolved)?;
56                namespace.push_element(code);
57
58                let ns = namespace.get_ident();
59                let name = format_ident!("{}", r#enum.name);
60
61                Some(quote! {
62                    #ns::#name
63                })
64            }
65            Self::Array(array) => {
66                let inner_ty = array.codegen(namespace, resolved)?;
67
68                Some(quote! {
69                    Vec<#inner_ty>
70                })
71            }
72            Self::Nested(nested) => {
73                let code = nested.codegen(resolved)?;
74                namespace.push_element(code);
75
76                let ns = namespace.get_ident();
77                let name = format_ident!("{}", nested.name);
78
79                Some(quote! {
80                    #ns::#name
81                })
82            }
83            Self::Any => Some(quote! {
84                serde_json::Value
85            }),
86        }
87    }
88}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
91pub struct Property {
92    pub field_name: String,
93    pub name: String,
94    pub description: Option<String>,
95    pub required: bool,
96    pub nullable: bool,
97    pub r#type: PropertyType,
98    pub deprecated: bool,
99}
100
101impl Property {
102    pub fn from_schema(
103        name: &str,
104        required: bool,
105        schema: &OpenApiType,
106        schemas: &IndexMap<&str, OpenApiType>,
107        warnings: WarningReporter,
108    ) -> Option<Self> {
109        let name = name.to_owned();
110        let field_name = name.to_snake_case();
111        let description = schema.description.as_deref().map(ToOwned::to_owned);
112
113        match schema {
114            OpenApiType {
115                r#enum: Some(_), ..
116            } => {
117                let Some(r#enum) = Enum::from_schema(&name.clone().to_upper_camel_case(), schema)
118                else {
119                    warnings.push("Failed to create enum");
120                    return None;
121                };
122                Some(Self {
123                    r#type: PropertyType::Enum(r#enum),
124                    name,
125                    field_name,
126                    description,
127                    required,
128                    deprecated: schema.deprecated,
129                    nullable: false,
130                })
131            }
132            OpenApiType {
133                one_of: Some(types),
134                ..
135            } => match types.as_slice() {
136                [left, OpenApiType {
137                    r#type: Some("null"),
138                    ..
139                }] => {
140                    let mut inner = Self::from_schema(&name, required, left, schemas, warnings)?;
141                    inner.nullable = true;
142                    Some(inner)
143                }
144                [left @ .., OpenApiType {
145                    r#type: Some("null"),
146                    ..
147                }] => {
148                    let rest = OpenApiType {
149                        one_of: Some(left.to_owned()),
150                        ..schema.clone()
151                    };
152                    let mut inner = Self::from_schema(&name, required, &rest, schemas, warnings)?;
153                    inner.nullable = true;
154                    Some(inner)
155                }
156                cases => {
157                    let Some(r#enum) = Enum::from_one_of(&name.to_upper_camel_case(), cases) else {
158                        warnings.push("Failed to create oneOf enum");
159                        return None;
160                    };
161
162                    Some(Self {
163                        name,
164                        field_name,
165                        description,
166                        required,
167                        nullable: false,
168                        deprecated: schema.deprecated,
169                        r#type: PropertyType::Enum(r#enum),
170                    })
171                }
172            },
173            OpenApiType {
174                all_of: Some(types),
175                ..
176            } => {
177                let obj_name = name.to_upper_camel_case();
178                let composite =
179                    Object::from_all_of(&obj_name, types, schemas, warnings.child(&obj_name));
180                Some(Self {
181                    name,
182                    field_name,
183                    description,
184                    required,
185                    nullable: false,
186                    deprecated: schema.deprecated,
187                    r#type: PropertyType::Nested(Box::new(composite)),
188                })
189            }
190            OpenApiType {
191                r#type: Some("object"),
192                properties: None,
193                ..
194            } => Some(Self {
195                field_name,
196                name,
197                description,
198                required,
199                nullable: false,
200                r#type: PropertyType::Any,
201                deprecated: schema.deprecated,
202            }),
203            OpenApiType {
204                r#type: Some("object"),
205                ..
206            } => {
207                let obj_name = name.to_upper_camel_case();
208                Some(Self {
209                    r#type: PropertyType::Nested(Box::new(Object::from_schema_object(
210                        &obj_name,
211                        schema,
212                        schemas,
213                        warnings.child(&obj_name),
214                    ))),
215                    name,
216                    field_name,
217                    description,
218                    required,
219                    deprecated: schema.deprecated,
220                    nullable: false,
221                })
222            }
223            OpenApiType {
224                ref_path: Some(path),
225                ..
226            } => Some(Self {
227                name,
228                field_name,
229                description,
230                r#type: PropertyType::Ref((*path).to_owned()),
231                required,
232                deprecated: schema.deprecated,
233                nullable: false,
234            }),
235            OpenApiType {
236                r#type: Some("array"),
237                items: Some(items),
238                ..
239            } => {
240                let inner = Self::from_schema(&name, required, items, schemas, warnings)?;
241
242                Some(Self {
243                    name,
244                    field_name,
245                    description,
246                    required,
247                    nullable: false,
248                    deprecated: schema.deprecated,
249                    r#type: PropertyType::Array(Box::new(inner.r#type)),
250                })
251            }
252            OpenApiType {
253                r#type: Some(_), ..
254            } => {
255                let prim = match (schema.r#type, schema.format) {
256                    (Some("integer"), Some("int32")) => PrimitiveType::I32,
257                    (Some("integer"), Some("int64")) => PrimitiveType::I64,
258                    (Some("number"), /* Some("float") */ _) | (_, Some("float")) => {
259                        PrimitiveType::Float
260                    }
261                    (Some("string"), None) => PrimitiveType::String,
262                    (Some("string"), Some("date")) => PrimitiveType::DateTime,
263                    (Some("boolean"), None) => PrimitiveType::Bool,
264                    _ => return None,
265                };
266
267                Some(Self {
268                    name,
269                    field_name,
270                    description,
271                    required,
272                    nullable: false,
273                    deprecated: schema.deprecated,
274                    r#type: PropertyType::Primitive(prim),
275                })
276            }
277            _ => {
278                warnings.push("Could not resolve property type");
279                None
280            }
281        }
282    }
283
284    pub fn codegen(
285        &self,
286        namespace: &mut ObjectNamespace,
287        resolved: &ResolvedSchema,
288    ) -> Option<TokenStream> {
289        let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]});
290
291        let name = &self.name;
292        let (name, serde_attr) = match name.as_str() {
293            // https://doc.rust-lang.org/reference/keywords.html#r-lex.keywords
294            "as" => (format_ident!("r#as"), None),
295            "break" => (format_ident!("r#break"), None),
296            "const" => (format_ident!("r#const"), None),
297            "continue" => (format_ident!("r#continue"), None),
298            "crate" => (format_ident!("r#crate"), None),
299            "else" => (format_ident!("r#else"), None),
300            "enum" => (format_ident!("r#enum"), None),
301            "extern" => (format_ident!("r#extern"), None),
302            "false" => (format_ident!("r#false"), None),
303            "fn" => (format_ident!("r#fn"), None),
304            "for" => (format_ident!("r#for"), None),
305            "if" => (format_ident!("r#if"), None),
306            "impl" => (format_ident!("r#impl"), None),
307            "in" => (format_ident!("r#in"), None),
308            "let" => (format_ident!("r#let"), None),
309            "loop" => (format_ident!("r#loop"), None),
310            "match" => (format_ident!("r#match"), None),
311            "mod" => (format_ident!("r#mod"), None),
312            "move" => (format_ident!("r#move"), None),
313            "mut" => (format_ident!("r#mut"), None),
314            "pub" => (format_ident!("r#pub"), None),
315            "ref" => (format_ident!("r#ref"), None),
316            "return" => (format_ident!("r#return"), None),
317            "self" => (format_ident!("r#self"), None),
318            "Self" => (format_ident!("r#Self"), None),
319            "static" => (format_ident!("r#static"), None),
320            "struct" => (format_ident!("r#struct"), None),
321            "super" => (format_ident!("r#super"), None),
322            "trait" => (format_ident!("r#trait"), None),
323            "true" => (format_ident!("r#true"), None),
324            "type" => (format_ident!("r#type"), None),
325            "unsafe" => (format_ident!("r#unsafe"), None),
326            "use" => (format_ident!("r#use"), None),
327            "where" => (format_ident!("r#where"), None),
328            "while" => (format_ident!("r#while"), None),
329            "async" => (format_ident!("r#async"), None),
330            "await" => (format_ident!("r#await"), None),
331            "dyn" => (format_ident!("r#dyn"), None),
332            "abstract" => (format_ident!("r#abstract"), None),
333            "become" => (format_ident!("r#become"), None),
334            "box" => (format_ident!("r#box"), None),
335            "do" => (format_ident!("r#do"), None),
336            "final" => (format_ident!("r#final"), None),
337            "macro" => (format_ident!("r#macro"), None),
338            "override" => (format_ident!("r#override"), None),
339            "priv" => (format_ident!("r#priv"), None),
340            "typeof" => (format_ident!("r#typeof"), None),
341            "unsized" => (format_ident!("r#unsized"), None),
342            "virtual" => (format_ident!("r#virtual"), None),
343            "yield" => (format_ident!("r#yield"), None),
344            "try" => (format_ident!("r#try"), None),
345            "gen" => (format_ident!("r#gen"), None),
346
347            name if name != self.field_name => (
348                format_ident!("{}", self.field_name),
349                Some(quote! { #[serde(rename = #name)]}),
350            ),
351            _ => (format_ident!("{}", self.field_name), None),
352        };
353
354        let ty_inner = self.r#type.codegen(namespace, resolved)?;
355
356        let ty = if !self.required || self.nullable {
357            quote! { Option<#ty_inner> }
358        } else {
359            ty_inner
360        };
361
362        let deprecated = self.deprecated.then(|| {
363            let note = self.description.as_ref().map(|d| quote! { note = #d });
364
365            quote! {
366                #[deprecated(#note)]
367            }
368        });
369
370        Some(quote! {
371            #desc
372            #deprecated
373            #serde_attr
374            pub #name: #ty
375        })
376    }
377}
378
379#[derive(Debug, Clone, PartialEq, Eq, Default)]
380pub struct Object {
381    pub name: String,
382    pub description: Option<String>,
383    pub properties: IndexMap<String, Property>,
384}
385
386impl Object {
387    pub fn from_schema_object(
388        name: &str,
389        schema: &OpenApiType,
390        schemas: &IndexMap<&str, OpenApiType>,
391        warnings: WarningReporter,
392    ) -> Self {
393        let mut result = Object {
394            name: name.to_owned(),
395            description: schema.description.as_deref().map(ToOwned::to_owned),
396            ..Default::default()
397        };
398
399        let Some(props) = &schema.properties else {
400            warnings.push("Missing properties");
401            return result;
402        };
403
404        let required = schema.required.clone().unwrap_or_default();
405
406        for (prop_name, prop) in props {
407            let Some(prop) = Property::from_schema(
408                prop_name,
409                required.contains(prop_name),
410                prop,
411                schemas,
412                warnings.child(prop_name),
413            ) else {
414                continue;
415            };
416
417            let field_name = prop.field_name.clone();
418
419            let entry = result.properties.entry(field_name.clone());
420            if let Entry::Occupied(mut entry) = entry {
421                let other_name = entry.get().name.clone();
422                warnings.push(format!(
423                    "Property name collision: {other_name} and {field_name}"
424                ));
425                // deprioritise kebab and camelcase
426                if other_name.contains('-')
427                    || other_name
428                        .chars()
429                        .filter(|c| c.is_alphabetic())
430                        .any(|c| c.is_ascii_uppercase())
431                {
432                    entry.insert(prop);
433                }
434            } else {
435                entry.insert_entry(prop);
436            }
437        }
438
439        result
440    }
441
442    pub fn from_all_of(
443        name: &str,
444        types: &[OpenApiType],
445        schemas: &IndexMap<&str, OpenApiType>,
446        warnings: WarningReporter,
447    ) -> Self {
448        let mut result = Self {
449            name: name.to_owned(),
450            ..Default::default()
451        };
452
453        for r#type in types {
454            let r#type = if let OpenApiType {
455                ref_path: Some(ref_path),
456                ..
457            } = r#type
458            {
459                let Some(name) = ref_path.strip_prefix("#/components/schemas/") else {
460                    warnings.push(format!("Malformed ref {ref_path}"));
461                    continue;
462                };
463                let Some(schema) = schemas.get(name) else {
464                    warnings.push(format!("Missing schema for ref {name}"));
465                    continue;
466                };
467                schema
468            } else {
469                r#type
470            };
471            let obj = Self::from_schema_object(name, r#type, schemas, warnings.child("variant"));
472
473            result.description = result.description.or(obj.description);
474            result.properties.extend(obj.properties);
475        }
476
477        result
478    }
479
480    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
481        let doc = self.description.as_ref().map(|d| {
482            quote! {
483                #[doc = #d]
484            }
485        });
486
487        let mut namespace = ObjectNamespace {
488            object: self,
489            ident: None,
490            elements: Vec::default(),
491        };
492
493        let mut props = Vec::with_capacity(self.properties.len());
494        for (_, prop) in &self.properties {
495            props.push(prop.codegen(&mut namespace, resolved)?);
496        }
497
498        let name = format_ident!("{}", self.name);
499        let ns = namespace.codegen();
500
501        Some(quote! {
502            #ns
503
504            #doc
505            #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
506            pub struct #name {
507                #(#props),*
508            }
509        })
510    }
511}
512
513pub struct ObjectNamespace<'o> {
514    object: &'o Object,
515    ident: Option<Ident>,
516    elements: Vec<TokenStream>,
517}
518
519impl ObjectNamespace<'_> {
520    pub fn get_ident(&mut self) -> Ident {
521        self.ident
522            .get_or_insert_with(|| {
523                let name = self.object.name.to_snake_case();
524                format_ident!("{name}")
525            })
526            .clone()
527    }
528
529    pub fn push_element(&mut self, el: TokenStream) {
530        self.elements.push(el);
531    }
532
533    pub fn codegen(mut self) -> Option<TokenStream> {
534        if self.elements.is_empty() {
535            None
536        } else {
537            let ident = self.get_ident();
538            let elements = self.elements;
539            Some(quote! {
540                pub mod #ident {
541                    #(#elements)*
542                }
543            })
544        }
545    }
546}
547
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Every `type: object` component schema must convert into an [`Object`] without
    /// reporting a single warning.
    #[test]
    fn resolve_objects() {
        let schema = get_schema();
        let schemas = &schema.components.schemas;

        let objects: Vec<_> = schemas
            .iter()
            .filter(|(_, desc)| desc.r#type == Some("object"))
            .collect();

        let unresolved: Vec<String> = objects
            .iter()
            .filter(|(name, desc)| {
                let reporter = WarningReporter::new();
                Object::from_schema_object(name, desc, schemas, reporter.clone());
                !reporter.is_empty()
            })
            .map(|(name, _)| format!("`{name}`"))
            .collect();

        if !unresolved.is_empty() {
            panic!(
                "Failed to resolve {}/{} objects. Could not resolve [{}]",
                unresolved.len(),
                objects.len(),
                unresolved.join(", ")
            )
        }
    }
}