torn_api_codegen/model/
object.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use indexmap::{map::Entry, IndexMap};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::Ident;
6
7use crate::openapi::r#type::OpenApiType;
8
9use super::{r#enum::Enum, ResolvedSchema, WarningReporter};
10
/// Scalar schema types that map directly onto a Rust type.
///
/// `PropertyType::codegen` emits: `Bool` -> `bool`, `I32` -> `i32`, `I64` -> `i64`,
/// `String` -> `String`, `Float` -> `f64`, `DateTime` -> `chrono::DateTime<chrono::Utc>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrimitiveType {
    /// `type: boolean` without a format.
    Bool,
    /// `type: integer` with `format: int32`.
    I32,
    /// `type: integer` with `format: int64`.
    I64,
    /// `type: string` without a format.
    String,
    /// Any `type: number`, or any type with `format: float`.
    Float,
    /// `type: string` with `format: date`.
    DateTime,
}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum PropertyType {
23    Primitive(PrimitiveType),
24    Ref(String),
25    Enum(Enum),
26    Nested(Box<Object>),
27    Array(Box<PropertyType>),
28}
29
30impl PropertyType {
31    pub fn codegen(
32        &self,
33        namespace: &mut ObjectNamespace,
34        resolved: &ResolvedSchema,
35    ) -> Option<TokenStream> {
36        match self {
37            Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()),
38            Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()),
39            Self::Primitive(PrimitiveType::I64) => Some(format_ident!("i64").into_token_stream()),
40            Self::Primitive(PrimitiveType::String) => {
41                Some(format_ident!("String").into_token_stream())
42            }
43            Self::Primitive(PrimitiveType::DateTime) => {
44                Some(quote! { chrono::DateTime<chrono::Utc> })
45            }
46            Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()),
47            Self::Ref(path) => {
48                let name = path.strip_prefix("#/components/schemas/")?;
49                let name = format_ident!("{name}");
50
51                Some(quote! { crate::models::#name })
52            }
53            Self::Enum(r#enum) => {
54                let code = r#enum.codegen(resolved)?;
55                namespace.push_element(code);
56
57                let ns = namespace.get_ident();
58                let name = format_ident!("{}", r#enum.name);
59
60                Some(quote! {
61                    #ns::#name
62                })
63            }
64            Self::Array(array) => {
65                let inner_ty = array.codegen(namespace, resolved)?;
66
67                Some(quote! {
68                    Vec<#inner_ty>
69                })
70            }
71            Self::Nested(nested) => {
72                let code = nested.codegen(resolved)?;
73                namespace.push_element(code);
74
75                let ns = namespace.get_ident();
76                let name = format_ident!("{}", nested.name);
77
78                Some(quote! {
79                    #ns::#name
80                })
81            }
82        }
83    }
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct Property {
88    pub field_name: String,
89    pub name: String,
90    pub description: Option<String>,
91    pub required: bool,
92    pub nullable: bool,
93    pub r#type: PropertyType,
94    pub deprecated: bool,
95}
96
97impl Property {
98    pub fn from_schema(
99        name: &str,
100        required: bool,
101        schema: &OpenApiType,
102        schemas: &IndexMap<&str, OpenApiType>,
103        warnings: WarningReporter,
104    ) -> Option<Self> {
105        let name = name.to_owned();
106        let field_name = name.to_snake_case();
107        let description = schema.description.as_deref().map(ToOwned::to_owned);
108
109        match schema {
110            OpenApiType {
111                r#enum: Some(_), ..
112            } => {
113                let Some(r#enum) = Enum::from_schema(&name.clone().to_upper_camel_case(), schema)
114                else {
115                    warnings.push("Failed to create enum");
116                    return None;
117                };
118                Some(Self {
119                    r#type: PropertyType::Enum(r#enum),
120                    name,
121                    field_name,
122                    description,
123                    required,
124                    deprecated: schema.deprecated,
125                    nullable: false,
126                })
127            }
128            OpenApiType {
129                one_of: Some(types),
130                ..
131            } => match types.as_slice() {
132                [left, OpenApiType {
133                    r#type: Some("null"),
134                    ..
135                }] => {
136                    let mut inner = Self::from_schema(&name, required, left, schemas, warnings)?;
137                    inner.nullable = true;
138                    Some(inner)
139                }
140                [left @ .., OpenApiType {
141                    r#type: Some("null"),
142                    ..
143                }] => {
144                    let rest = OpenApiType {
145                        one_of: Some(left.to_owned()),
146                        ..schema.clone()
147                    };
148                    let mut inner = Self::from_schema(&name, required, &rest, schemas, warnings)?;
149                    inner.nullable = true;
150                    Some(inner)
151                }
152                cases => {
153                    let Some(r#enum) = Enum::from_one_of(&name.to_upper_camel_case(), cases) else {
154                        warnings.push("Failed to create oneOf enum");
155                        return None;
156                    };
157
158                    Some(Self {
159                        name,
160                        field_name,
161                        description,
162                        required,
163                        nullable: false,
164                        deprecated: schema.deprecated,
165                        r#type: PropertyType::Enum(r#enum),
166                    })
167                }
168            },
169            OpenApiType {
170                all_of: Some(types),
171                ..
172            } => {
173                let obj_name = name.to_upper_camel_case();
174                let composite =
175                    Object::from_all_of(&obj_name, types, schemas, warnings.child(&obj_name));
176                Some(Self {
177                    name,
178                    field_name,
179                    description,
180                    required,
181                    nullable: false,
182                    deprecated: schema.deprecated,
183                    r#type: PropertyType::Nested(Box::new(composite)),
184                })
185            }
186            OpenApiType {
187                r#type: Some("object"),
188                ..
189            } => {
190                let obj_name = name.to_upper_camel_case();
191                Some(Self {
192                    r#type: PropertyType::Nested(Box::new(Object::from_schema_object(
193                        &obj_name,
194                        schema,
195                        schemas,
196                        warnings.child(&obj_name),
197                    ))),
198                    name,
199                    field_name,
200                    description,
201                    required,
202                    deprecated: schema.deprecated,
203                    nullable: false,
204                })
205            }
206            OpenApiType {
207                ref_path: Some(path),
208                ..
209            } => Some(Self {
210                name,
211                field_name,
212                description,
213                r#type: PropertyType::Ref((*path).to_owned()),
214                required,
215                deprecated: schema.deprecated,
216                nullable: false,
217            }),
218            OpenApiType {
219                r#type: Some("array"),
220                items: Some(items),
221                ..
222            } => {
223                let inner = Self::from_schema(&name, required, items, schemas, warnings)?;
224
225                Some(Self {
226                    name,
227                    field_name,
228                    description,
229                    required,
230                    nullable: false,
231                    deprecated: schema.deprecated,
232                    r#type: PropertyType::Array(Box::new(inner.r#type)),
233                })
234            }
235            OpenApiType {
236                r#type: Some(_), ..
237            } => {
238                let prim = match (schema.r#type, schema.format) {
239                    (Some("integer"), Some("int32")) => PrimitiveType::I32,
240                    (Some("integer"), Some("int64")) => PrimitiveType::I64,
241                    (Some("number"), /* Some("float") */ _) | (_, Some("float")) => {
242                        PrimitiveType::Float
243                    }
244                    (Some("string"), None) => PrimitiveType::String,
245                    (Some("string"), Some("date")) => PrimitiveType::DateTime,
246                    (Some("boolean"), None) => PrimitiveType::Bool,
247                    _ => return None,
248                };
249
250                Some(Self {
251                    name,
252                    field_name,
253                    description,
254                    required,
255                    nullable: false,
256                    deprecated: schema.deprecated,
257                    r#type: PropertyType::Primitive(prim),
258                })
259            }
260            _ => {
261                warnings.push("Could not resolve property type");
262                None
263            }
264        }
265    }
266
267    pub fn codegen(
268        &self,
269        namespace: &mut ObjectNamespace,
270        resolved: &ResolvedSchema,
271    ) -> Option<TokenStream> {
272        let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]});
273
274        let name = &self.name;
275        let (name, serde_attr) = match name.as_str() {
276            "type" => (format_ident!("r#type"), None),
277            name if name != self.field_name => (
278                format_ident!("{}", self.field_name),
279                Some(quote! { #[serde(rename = #name)]}),
280            ),
281            _ => (format_ident!("{}", self.field_name), None),
282        };
283
284        let ty_inner = self.r#type.codegen(namespace, resolved)?;
285
286        let ty = if !self.required || self.nullable {
287            quote! { Option<#ty_inner> }
288        } else {
289            ty_inner
290        };
291
292        let deprecated = self.deprecated.then(|| {
293            let note = self.description.as_ref().map(|d| quote! { note = #d });
294
295            quote! {
296                #[deprecated(#note)]
297            }
298        });
299
300        Some(quote! {
301            #desc
302            #deprecated
303            #serde_attr
304            pub #name: #ty
305        })
306    }
307}
308
309#[derive(Debug, Clone, PartialEq, Eq, Default)]
310pub struct Object {
311    pub name: String,
312    pub description: Option<String>,
313    pub properties: IndexMap<String, Property>,
314}
315
316impl Object {
317    pub fn from_schema_object(
318        name: &str,
319        schema: &OpenApiType,
320        schemas: &IndexMap<&str, OpenApiType>,
321        warnings: WarningReporter,
322    ) -> Self {
323        let mut result = Object {
324            name: name.to_owned(),
325            description: schema.description.as_deref().map(ToOwned::to_owned),
326            ..Default::default()
327        };
328
329        let Some(props) = &schema.properties else {
330            warnings.push("Missing properties");
331            return result;
332        };
333
334        let required = schema.required.clone().unwrap_or_default();
335
336        for (prop_name, prop) in props {
337            let Some(prop) = Property::from_schema(
338                prop_name,
339                required.contains(prop_name),
340                prop,
341                schemas,
342                warnings.child(prop_name),
343            ) else {
344                continue;
345            };
346
347            let field_name = prop.field_name.clone();
348
349            let entry = result.properties.entry(field_name.clone());
350            if let Entry::Occupied(mut entry) = entry {
351                let other_name = entry.get().name.clone();
352                warnings.push(format!(
353                    "Property name collision: {other_name} and {field_name}"
354                ));
355                // deprioritise kebab and camelcase
356                if other_name.contains('-')
357                    || other_name
358                        .chars()
359                        .filter(|c| c.is_alphabetic())
360                        .all(|c| c.is_ascii_lowercase())
361                {
362                    entry.insert(prop);
363                }
364            } else {
365                entry.insert_entry(prop);
366            }
367        }
368
369        result
370    }
371
372    pub fn from_all_of(
373        name: &str,
374        types: &[OpenApiType],
375        schemas: &IndexMap<&str, OpenApiType>,
376        warnings: WarningReporter,
377    ) -> Self {
378        let mut result = Self {
379            name: name.to_owned(),
380            ..Default::default()
381        };
382
383        for r#type in types {
384            let r#type = if let OpenApiType {
385                ref_path: Some(ref_path),
386                ..
387            } = r#type
388            {
389                let Some(name) = ref_path.strip_prefix("#/components/schemas/") else {
390                    warnings.push(format!("Malformed ref {ref_path}"));
391                    continue;
392                };
393                let Some(schema) = schemas.get(name) else {
394                    warnings.push(format!("Missing schema for ref {name}"));
395                    continue;
396                };
397                schema
398            } else {
399                r#type
400            };
401            let obj = Self::from_schema_object(name, r#type, schemas, warnings.child("variant"));
402
403            result.description = result.description.or(obj.description);
404            result.properties.extend(obj.properties);
405        }
406
407        result
408    }
409
410    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
411        let doc = self.description.as_ref().map(|d| {
412            quote! {
413                #[doc = #d]
414            }
415        });
416
417        let mut namespace = ObjectNamespace {
418            object: self,
419            ident: None,
420            elements: Vec::default(),
421        };
422
423        let mut props = Vec::with_capacity(self.properties.len());
424        for (_, prop) in &self.properties {
425            props.push(prop.codegen(&mut namespace, resolved)?);
426        }
427
428        let name = format_ident!("{}", self.name);
429        let ns = namespace.codegen();
430
431        Some(quote! {
432            #ns
433
434            #doc
435            #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
436            pub struct #name {
437                #(#props),*
438            }
439        })
440    }
441}
442
443pub struct ObjectNamespace<'o> {
444    object: &'o Object,
445    ident: Option<Ident>,
446    elements: Vec<TokenStream>,
447}
448
449impl ObjectNamespace<'_> {
450    pub fn get_ident(&mut self) -> Ident {
451        self.ident
452            .get_or_insert_with(|| {
453                let name = self.object.name.to_snake_case();
454                format_ident!("{name}")
455            })
456            .clone()
457    }
458
459    pub fn push_element(&mut self, el: TokenStream) {
460        self.elements.push(el);
461    }
462
463    pub fn codegen(mut self) -> Option<TokenStream> {
464        if self.elements.is_empty() {
465            None
466        } else {
467            let ident = self.get_ident();
468            let elements = self.elements;
469            Some(quote! {
470                pub mod #ident {
471                    #(#elements)*
472                }
473            })
474        }
475    }
476}
477
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Every `object` component schema must convert into an [`Object`] without
    /// pushing any warnings.
    #[test]
    fn resolve_objects() {
        let schema = get_schema();
        let schemas = &schema.components.schemas;

        let objects: Vec<_> = schemas
            .iter()
            .filter(|(_, desc)| desc.r#type == Some("object"))
            .collect();

        let unresolved: Vec<String> = objects
            .iter()
            .filter(|(name, desc)| {
                let reporter = WarningReporter::new();
                Object::from_schema_object(name, desc, schemas, reporter.clone());
                !reporter.is_empty()
            })
            .map(|(name, _)| format!("`{name}`"))
            .collect();

        assert!(
            unresolved.is_empty(),
            "Failed to resolve {}/{} objects. Could not resolve [{}]",
            unresolved.len(),
            objects.len(),
            unresolved.join(", ")
        );
    }
}