torn_api_codegen/model/object.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use indexmap::{map::Entry, IndexMap};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::Ident;
6
7use crate::openapi::r#type::OpenApiType;
8
9use super::{r#enum::Enum, ResolvedSchema, WarningReporter};
10
/// Scalar schema types that map directly onto a single Rust type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrimitiveType {
    /// `type: boolean`, emitted as `bool`.
    Bool,
    /// `type: integer, format: int32`, emitted as `i32`.
    I32,
    /// `type: integer, format: int64`, emitted as `i64`.
    I64,
    /// `type: string` without a format, emitted as `String`.
    String,
    /// `type: number` (or any `format: float`), emitted as `f64`.
    Float,
    /// `type: string, format: date`, emitted as `chrono::DateTime<chrono::Utc>`.
    DateTime,
}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum PropertyType {
23    Primitive(PrimitiveType),
24    Ref(String),
25    Enum(Enum),
26    Nested(Box<Object>),
27    Array(Box<PropertyType>),
28    Any,
29}
30
31impl PropertyType {
32    pub fn codegen(
33        &self,
34        namespace: &mut ObjectNamespace,
35        resolved: &ResolvedSchema,
36    ) -> Option<TokenStream> {
37        match self {
38            Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()),
39            Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()),
40            Self::Primitive(PrimitiveType::I64) => Some(format_ident!("i64").into_token_stream()),
41            Self::Primitive(PrimitiveType::String) => {
42                Some(format_ident!("String").into_token_stream())
43            }
44            Self::Primitive(PrimitiveType::DateTime) => {
45                Some(quote! { chrono::DateTime<chrono::Utc> })
46            }
47            Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()),
48            Self::Ref(path) => {
49                let name = path.strip_prefix("#/components/schemas/")?;
50                let name = format_ident!("{name}");
51
52                Some(quote! { crate::models::#name })
53            }
54            Self::Enum(r#enum) => {
55                let code = r#enum.codegen(resolved)?;
56                namespace.push_element(code);
57
58                let ns = namespace.get_ident();
59                let name = format_ident!("{}", r#enum.name);
60
61                Some(quote! {
62                    #ns::#name
63                })
64            }
65            Self::Array(array) => {
66                let inner_ty = array.codegen(namespace, resolved)?;
67
68                Some(quote! {
69                    Vec<#inner_ty>
70                })
71            }
72            Self::Nested(nested) => {
73                let code = nested.codegen(resolved)?;
74                namespace.push_element(code);
75
76                let ns = namespace.get_ident();
77                let name = format_ident!("{}", nested.name);
78
79                Some(quote! {
80                    #ns::#name
81                })
82            }
83            Self::Any => Some(quote! {
84                serde_json::Value
85            }),
86        }
87    }
88}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
91pub struct Property {
92    pub field_name: String,
93    pub name: String,
94    pub description: Option<String>,
95    pub required: bool,
96    pub nullable: bool,
97    pub r#type: PropertyType,
98    pub deprecated: bool,
99}
100
101impl Property {
102    pub fn from_schema(
103        name: &str,
104        required: bool,
105        schema: &OpenApiType,
106        schemas: &IndexMap<&str, OpenApiType>,
107        warnings: WarningReporter,
108    ) -> Option<Self> {
109        let name = name.to_owned();
110        let field_name = name.to_snake_case();
111        let description = schema.description.as_deref().map(ToOwned::to_owned);
112
113        match schema {
114            OpenApiType {
115                r#enum: Some(_), ..
116            } => {
117                let Some(r#enum) = Enum::from_schema(
118                    &name.clone().to_upper_camel_case(),
119                    schema,
120                    warnings.clone(),
121                ) else {
122                    warnings.push("Failed to create enum");
123                    return None;
124                };
125                Some(Self {
126                    r#type: PropertyType::Enum(r#enum),
127                    name,
128                    field_name,
129                    description,
130                    required,
131                    deprecated: schema.deprecated,
132                    nullable: false,
133                })
134            }
135            OpenApiType {
136                one_of: Some(types),
137                ..
138            } => match types.as_slice() {
139                [left, OpenApiType {
140                    r#type: Some("null"),
141                    ..
142                }] => {
143                    let mut inner = Self::from_schema(&name, required, left, schemas, warnings)?;
144                    inner.nullable = true;
145                    Some(inner)
146                }
147                [left @ .., OpenApiType {
148                    r#type: Some("null"),
149                    ..
150                }] => {
151                    let rest = OpenApiType {
152                        one_of: Some(left.to_owned()),
153                        ..schema.clone()
154                    };
155                    let mut inner = Self::from_schema(&name, required, &rest, schemas, warnings)?;
156                    inner.nullable = true;
157                    Some(inner)
158                }
159                cases => {
160                    let Some(r#enum) =
161                        Enum::from_one_of(&name.to_upper_camel_case(), cases, warnings.clone())
162                    else {
163                        warnings.push("Failed to create oneOf enum");
164                        return None;
165                    };
166
167                    Some(Self {
168                        name,
169                        field_name,
170                        description,
171                        required,
172                        nullable: false,
173                        deprecated: schema.deprecated,
174                        r#type: PropertyType::Enum(r#enum),
175                    })
176                }
177            },
178            OpenApiType {
179                all_of: Some(types),
180                ..
181            } => {
182                let obj_name = name.to_upper_camel_case();
183                let composite =
184                    Object::from_all_of(&obj_name, types, schemas, warnings.child(&obj_name));
185                Some(Self {
186                    name,
187                    field_name,
188                    description,
189                    required,
190                    nullable: false,
191                    deprecated: schema.deprecated,
192                    r#type: PropertyType::Nested(Box::new(composite)),
193                })
194            }
195            OpenApiType {
196                r#type: Some("object"),
197                properties: None,
198                ..
199            } => Some(Self {
200                field_name,
201                name,
202                description,
203                required,
204                nullable: false,
205                r#type: PropertyType::Any,
206                deprecated: schema.deprecated,
207            }),
208            OpenApiType {
209                r#type: Some("object"),
210                ..
211            } => {
212                let obj_name = name.to_upper_camel_case();
213                Some(Self {
214                    r#type: PropertyType::Nested(Box::new(Object::from_schema_object(
215                        &obj_name,
216                        schema,
217                        schemas,
218                        warnings.child(&obj_name),
219                    ))),
220                    name,
221                    field_name,
222                    description,
223                    required,
224                    deprecated: schema.deprecated,
225                    nullable: false,
226                })
227            }
228            OpenApiType {
229                ref_path: Some(path),
230                ..
231            } => Some(Self {
232                name,
233                field_name,
234                description,
235                r#type: PropertyType::Ref((*path).to_owned()),
236                required,
237                deprecated: schema.deprecated,
238                nullable: false,
239            }),
240            OpenApiType {
241                r#type: Some("array"),
242                items: Some(items),
243                ..
244            } => {
245                let inner = Self::from_schema(&name, required, items, schemas, warnings)?;
246
247                Some(Self {
248                    name,
249                    field_name,
250                    description,
251                    required,
252                    nullable: false,
253                    deprecated: schema.deprecated,
254                    r#type: PropertyType::Array(Box::new(inner.r#type)),
255                })
256            }
257            OpenApiType {
258                r#type: Some(_), ..
259            } => {
260                let prim = match (schema.r#type, schema.format) {
261                    (Some("integer"), Some("int32")) => PrimitiveType::I32,
262                    (Some("integer"), Some("int64")) => PrimitiveType::I64,
263                    (Some("number"), /* Some("float") */ _) | (_, Some("float")) => {
264                        PrimitiveType::Float
265                    }
266                    (Some("string"), None) => PrimitiveType::String,
267                    (Some("string"), Some("date")) => PrimitiveType::DateTime,
268                    (Some("boolean"), None) => PrimitiveType::Bool,
269                    _ => return None,
270                };
271
272                Some(Self {
273                    name,
274                    field_name,
275                    description,
276                    required,
277                    nullable: false,
278                    deprecated: schema.deprecated,
279                    r#type: PropertyType::Primitive(prim),
280                })
281            }
282            _ => {
283                warnings.push("Could not resolve property type");
284                None
285            }
286        }
287    }
288
289    pub fn codegen(
290        &self,
291        namespace: &mut ObjectNamespace,
292        resolved: &ResolvedSchema,
293    ) -> Option<TokenStream> {
294        let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]});
295
296        let name = &self.name;
297        let (name, serde_attr) = match name.as_str() {
298            // https://doc.rust-lang.org/reference/keywords.html#r-lex.keywords
299            "as" => (format_ident!("r#as"), None),
300            "break" => (format_ident!("r#break"), None),
301            "const" => (format_ident!("r#const"), None),
302            "continue" => (format_ident!("r#continue"), None),
303            "crate" => (format_ident!("r#crate"), None),
304            "else" => (format_ident!("r#else"), None),
305            "enum" => (format_ident!("r#enum"), None),
306            "extern" => (format_ident!("r#extern"), None),
307            "false" => (format_ident!("r#false"), None),
308            "fn" => (format_ident!("r#fn"), None),
309            "for" => (format_ident!("r#for"), None),
310            "if" => (format_ident!("r#if"), None),
311            "impl" => (format_ident!("r#impl"), None),
312            "in" => (format_ident!("r#in"), None),
313            "let" => (format_ident!("r#let"), None),
314            "loop" => (format_ident!("r#loop"), None),
315            "match" => (format_ident!("r#match"), None),
316            "mod" => (format_ident!("r#mod"), None),
317            "move" => (format_ident!("r#move"), None),
318            "mut" => (format_ident!("r#mut"), None),
319            "pub" => (format_ident!("r#pub"), None),
320            "ref" => (format_ident!("r#ref"), None),
321            "return" => (format_ident!("r#return"), None),
322            "self" => (format_ident!("r#self"), None),
323            "Self" => (format_ident!("r#Self"), None),
324            "static" => (format_ident!("r#static"), None),
325            "struct" => (format_ident!("r#struct"), None),
326            "super" => (format_ident!("r#super"), None),
327            "trait" => (format_ident!("r#trait"), None),
328            "true" => (format_ident!("r#true"), None),
329            "type" => (format_ident!("r#type"), None),
330            "unsafe" => (format_ident!("r#unsafe"), None),
331            "use" => (format_ident!("r#use"), None),
332            "where" => (format_ident!("r#where"), None),
333            "while" => (format_ident!("r#while"), None),
334            "async" => (format_ident!("r#async"), None),
335            "await" => (format_ident!("r#await"), None),
336            "dyn" => (format_ident!("r#dyn"), None),
337            "abstract" => (format_ident!("r#abstract"), None),
338            "become" => (format_ident!("r#become"), None),
339            "box" => (format_ident!("r#box"), None),
340            "do" => (format_ident!("r#do"), None),
341            "final" => (format_ident!("r#final"), None),
342            "macro" => (format_ident!("r#macro"), None),
343            "override" => (format_ident!("r#override"), None),
344            "priv" => (format_ident!("r#priv"), None),
345            "typeof" => (format_ident!("r#typeof"), None),
346            "unsized" => (format_ident!("r#unsized"), None),
347            "virtual" => (format_ident!("r#virtual"), None),
348            "yield" => (format_ident!("r#yield"), None),
349            "try" => (format_ident!("r#try"), None),
350            "gen" => (format_ident!("r#gen"), None),
351
352            name if name != self.field_name => (
353                format_ident!("{}", self.field_name),
354                Some(quote! { #[serde(rename = #name)]}),
355            ),
356            _ => (format_ident!("{}", self.field_name), None),
357        };
358
359        let ty_inner = self.r#type.codegen(namespace, resolved)?;
360
361        let ty = if !self.required || self.nullable {
362            quote! { Option<#ty_inner> }
363        } else {
364            ty_inner
365        };
366
367        let deprecated = self.deprecated.then(|| {
368            let note = self.description.as_ref().map(|d| quote! { note = #d });
369
370            quote! {
371                #[deprecated(#note)]
372            }
373        });
374
375        Some(quote! {
376            #desc
377            #deprecated
378            #serde_attr
379            pub #name: #ty
380        })
381    }
382}
383
384#[derive(Debug, Clone, PartialEq, Eq, Default)]
385pub struct Object {
386    pub name: String,
387    pub description: Option<String>,
388    pub properties: IndexMap<String, Property>,
389}
390
391impl Object {
392    pub fn from_schema_object(
393        name: &str,
394        schema: &OpenApiType,
395        schemas: &IndexMap<&str, OpenApiType>,
396        warnings: WarningReporter,
397    ) -> Self {
398        let mut result = Object {
399            name: name.to_owned(),
400            description: schema.description.as_deref().map(ToOwned::to_owned),
401            ..Default::default()
402        };
403
404        let Some(props) = &schema.properties else {
405            warnings.push("Missing properties");
406            return result;
407        };
408
409        let required = schema.required.clone().unwrap_or_default();
410
411        for (prop_name, prop) in props {
412            let Some(prop) = Property::from_schema(
413                prop_name,
414                required.contains(prop_name),
415                prop,
416                schemas,
417                warnings.child(prop_name),
418            ) else {
419                continue;
420            };
421
422            let field_name = prop.field_name.clone();
423
424            let entry = result.properties.entry(field_name.clone());
425            if let Entry::Occupied(mut entry) = entry {
426                let other_name = entry.get().name.clone();
427                warnings.push(format!(
428                    "Property name collision: {other_name} and {field_name}"
429                ));
430                // deprioritise kebab and camelcase
431                if other_name.contains('-')
432                    || other_name
433                        .chars()
434                        .filter(|c| c.is_alphabetic())
435                        .any(|c| c.is_ascii_uppercase())
436                {
437                    entry.insert(prop);
438                }
439            } else {
440                entry.insert_entry(prop);
441            }
442        }
443
444        result
445    }
446
447    pub fn from_all_of(
448        name: &str,
449        types: &[OpenApiType],
450        schemas: &IndexMap<&str, OpenApiType>,
451        warnings: WarningReporter,
452    ) -> Self {
453        let mut result = Self {
454            name: name.to_owned(),
455            ..Default::default()
456        };
457
458        for r#type in types {
459            let r#type = if let OpenApiType {
460                ref_path: Some(ref_path),
461                ..
462            } = r#type
463            {
464                let Some(name) = ref_path.strip_prefix("#/components/schemas/") else {
465                    warnings.push(format!("Malformed ref {ref_path}"));
466                    continue;
467                };
468                let Some(schema) = schemas.get(name) else {
469                    warnings.push(format!("Missing schema for ref {name}"));
470                    continue;
471                };
472                schema
473            } else {
474                r#type
475            };
476            let obj = if let Some(types) = &r#type.all_of {
477                Self::from_all_of(name, types, schemas, warnings.child("variant"))
478            } else {
479                Self::from_schema_object(name, r#type, schemas, warnings.child("variant"))
480            };
481
482            result.description = result.description.or(obj.description);
483            result.properties.extend(obj.properties);
484        }
485
486        result
487    }
488
489    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
490        let doc = self.description.as_ref().map(|d| {
491            quote! {
492                #[doc = #d]
493            }
494        });
495
496        let mut namespace = ObjectNamespace {
497            object: self,
498            ident: None,
499            elements: Vec::default(),
500        };
501
502        let mut props = Vec::with_capacity(self.properties.len());
503        for (_, prop) in &self.properties {
504            props.push(prop.codegen(&mut namespace, resolved)?);
505        }
506
507        let name = format_ident!("{}", self.name);
508        let ns = namespace.codegen();
509
510        Some(quote! {
511            #ns
512
513            #doc
514            #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
515            pub struct #name {
516                #(#props),*
517            }
518        })
519    }
520}
521
522pub struct ObjectNamespace<'o> {
523    object: &'o Object,
524    ident: Option<Ident>,
525    elements: Vec<TokenStream>,
526}
527
528impl ObjectNamespace<'_> {
529    pub fn get_ident(&mut self) -> Ident {
530        self.ident
531            .get_or_insert_with(|| {
532                let name = self.object.name.to_snake_case();
533                format_ident!("{name}")
534            })
535            .clone()
536    }
537
538    pub fn push_element(&mut self, el: TokenStream) {
539        self.elements.push(el);
540    }
541
542    pub fn codegen(mut self) -> Option<TokenStream> {
543        if self.elements.is_empty() {
544            None
545        } else {
546            let ident = self.get_ident();
547            let elements = self.elements;
548            Some(quote! {
549                pub mod #ident {
550                    #(#elements)*
551                }
552            })
553        }
554    }
555}
556
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Every `type: object` component schema must convert into an [`Object`]
    /// without pushing any warnings.
    #[test]
    fn resolve_objects() {
        let schema = get_schema();

        let object_schemas: Vec<_> = schema
            .components
            .schemas
            .iter()
            .filter(|(_, desc)| desc.r#type == Some("object"))
            .collect();

        let unresolved: Vec<String> = object_schemas
            .iter()
            .filter(|(name, desc)| {
                let reporter = WarningReporter::new();
                Object::from_schema_object(
                    name,
                    desc,
                    &schema.components.schemas,
                    reporter.clone(),
                );
                !reporter.is_empty()
            })
            .map(|(name, _)| format!("`{name}`"))
            .collect();

        if !unresolved.is_empty() {
            panic!(
                "Failed to resolve {}/{} objects. Could not resolve [{}]",
                unresolved.len(),
                object_schemas.len(),
                unresolved.join(", ")
            )
        }
    }
}