torn_api_codegen/model/object.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use indexmap::{map::Entry, IndexMap};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::Ident;
6
7use crate::openapi::r#type::OpenApiType;
8
9use super::{r#enum::Enum, ResolvedSchema, WarningReporter};
10
/// Scalar types that a schema property maps onto directly.
///
/// In generated code these become `bool`, `i32`, `i64`, `String`, `f64` and
/// `chrono::DateTime<chrono::Utc>` respectively (see `PropertyType::codegen`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrimitiveType {
    /// `type: boolean`
    Bool,
    /// `type: integer` with `format: int32`
    I32,
    /// `type: integer` with `format: int64`
    I64,
    /// `type: string` without a format
    String,
    /// `type: number` (any format) or any schema with `format: float`
    Float,
    /// `type: string` with `format: date`
    DateTime,
}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum PropertyType {
23    Primitive(PrimitiveType),
24    Ref(String),
25    Enum(Enum),
26    Nested(Box<Object>),
27    Array(Box<PropertyType>),
28}
29
30impl PropertyType {
31    pub fn codegen(
32        &self,
33        namespace: &mut ObjectNamespace,
34        resolved: &ResolvedSchema,
35    ) -> Option<TokenStream> {
36        match self {
37            Self::Primitive(PrimitiveType::Bool) => Some(format_ident!("bool").into_token_stream()),
38            Self::Primitive(PrimitiveType::I32) => Some(format_ident!("i32").into_token_stream()),
39            Self::Primitive(PrimitiveType::I64) => Some(format_ident!("i64").into_token_stream()),
40            Self::Primitive(PrimitiveType::String) => {
41                Some(format_ident!("String").into_token_stream())
42            }
43            Self::Primitive(PrimitiveType::DateTime) => {
44                Some(quote! { chrono::DateTime<chrono::Utc> })
45            }
46            Self::Primitive(PrimitiveType::Float) => Some(format_ident!("f64").into_token_stream()),
47            Self::Ref(path) => {
48                let name = path.strip_prefix("#/components/schemas/")?;
49                let name = format_ident!("{name}");
50
51                Some(quote! { crate::models::#name })
52            }
53            Self::Enum(r#enum) => {
54                let code = r#enum.codegen(resolved)?;
55                namespace.push_element(code);
56
57                let ns = namespace.get_ident();
58                let name = format_ident!("{}", r#enum.name);
59
60                Some(quote! {
61                    #ns::#name
62                })
63            }
64            Self::Array(array) => {
65                let inner_ty = array.codegen(namespace, resolved)?;
66
67                Some(quote! {
68                    Vec<#inner_ty>
69                })
70            }
71            Self::Nested(nested) => {
72                let code = nested.codegen(resolved)?;
73                namespace.push_element(code);
74
75                let ns = namespace.get_ident();
76                let name = format_ident!("{}", nested.name);
77
78                Some(quote! {
79                    #ns::#name
80                })
81            }
82        }
83    }
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct Property {
88    pub field_name: String,
89    pub name: String,
90    pub description: Option<String>,
91    pub required: bool,
92    pub nullable: bool,
93    pub r#type: PropertyType,
94    pub deprecated: bool,
95}
96
97impl Property {
98    pub fn from_schema(
99        name: &str,
100        required: bool,
101        schema: &OpenApiType,
102        schemas: &IndexMap<&str, OpenApiType>,
103        warnings: WarningReporter,
104    ) -> Option<Self> {
105        let name = name.to_owned();
106        let field_name = name.to_snake_case();
107        let description = schema.description.as_deref().map(ToOwned::to_owned);
108
109        match schema {
110            OpenApiType {
111                r#enum: Some(_), ..
112            } => {
113                let Some(r#enum) = Enum::from_schema(&name.clone().to_upper_camel_case(), schema)
114                else {
115                    warnings.push("Failed to create enum");
116                    return None;
117                };
118                Some(Self {
119                    r#type: PropertyType::Enum(r#enum),
120                    name,
121                    field_name,
122                    description,
123                    required,
124                    deprecated: schema.deprecated,
125                    nullable: false,
126                })
127            }
128            OpenApiType {
129                one_of: Some(types),
130                ..
131            } => match types.as_slice() {
132                [left, OpenApiType {
133                    r#type: Some("null"),
134                    ..
135                }] => {
136                    let mut inner = Self::from_schema(&name, required, left, schemas, warnings)?;
137                    inner.nullable = true;
138                    Some(inner)
139                }
140                [left @ .., OpenApiType {
141                    r#type: Some("null"),
142                    ..
143                }] => {
144                    let rest = OpenApiType {
145                        one_of: Some(left.to_owned()),
146                        ..schema.clone()
147                    };
148                    let mut inner = Self::from_schema(&name, required, &rest, schemas, warnings)?;
149                    inner.nullable = true;
150                    Some(inner)
151                }
152                cases => {
153                    let Some(r#enum) = Enum::from_one_of(&name.to_upper_camel_case(), cases) else {
154                        warnings.push("Failed to create oneOf enum");
155                        return None;
156                    };
157
158                    Some(Self {
159                        name,
160                        field_name,
161                        description,
162                        required,
163                        nullable: false,
164                        deprecated: schema.deprecated,
165                        r#type: PropertyType::Enum(r#enum),
166                    })
167                }
168            },
169            OpenApiType {
170                all_of: Some(types),
171                ..
172            } => {
173                let obj_name = name.to_upper_camel_case();
174                let composite =
175                    Object::from_all_of(&obj_name, types, schemas, warnings.child(&obj_name));
176                Some(Self {
177                    name,
178                    field_name,
179                    description,
180                    required,
181                    nullable: false,
182                    deprecated: schema.deprecated,
183                    r#type: PropertyType::Nested(Box::new(composite)),
184                })
185            }
186            OpenApiType {
187                r#type: Some("object"),
188                ..
189            } => {
190                let obj_name = name.to_upper_camel_case();
191                Some(Self {
192                    r#type: PropertyType::Nested(Box::new(Object::from_schema_object(
193                        &obj_name,
194                        schema,
195                        schemas,
196                        warnings.child(&obj_name),
197                    ))),
198                    name,
199                    field_name,
200                    description,
201                    required,
202                    deprecated: schema.deprecated,
203                    nullable: false,
204                })
205            }
206            OpenApiType {
207                ref_path: Some(path),
208                ..
209            } => Some(Self {
210                name,
211                field_name,
212                description,
213                r#type: PropertyType::Ref((*path).to_owned()),
214                required,
215                deprecated: schema.deprecated,
216                nullable: false,
217            }),
218            OpenApiType {
219                r#type: Some("array"),
220                items: Some(items),
221                ..
222            } => {
223                let inner = Self::from_schema(&name, required, items, schemas, warnings)?;
224
225                Some(Self {
226                    name,
227                    field_name,
228                    description,
229                    required,
230                    nullable: false,
231                    deprecated: schema.deprecated,
232                    r#type: PropertyType::Array(Box::new(inner.r#type)),
233                })
234            }
235            OpenApiType {
236                r#type: Some(_), ..
237            } => {
238                let prim = match (schema.r#type, schema.format) {
239                    (Some("integer"), Some("int32")) => PrimitiveType::I32,
240                    (Some("integer"), Some("int64")) => PrimitiveType::I64,
241                    (Some("number"), /* Some("float") */ _) | (_, Some("float")) => {
242                        PrimitiveType::Float
243                    }
244                    (Some("string"), None) => PrimitiveType::String,
245                    (Some("string"), Some("date")) => PrimitiveType::DateTime,
246                    (Some("boolean"), None) => PrimitiveType::Bool,
247                    _ => return None,
248                };
249
250                Some(Self {
251                    name,
252                    field_name,
253                    description,
254                    required,
255                    nullable: false,
256                    deprecated: schema.deprecated,
257                    r#type: PropertyType::Primitive(prim),
258                })
259            }
260            _ => {
261                warnings.push("Could not resolve property type");
262                None
263            }
264        }
265    }
266
267    pub fn codegen(
268        &self,
269        namespace: &mut ObjectNamespace,
270        resolved: &ResolvedSchema,
271    ) -> Option<TokenStream> {
272        let desc = self.description.as_ref().map(|d| quote! { #[doc = #d]});
273
274        let name = &self.name;
275        let (name, serde_attr) = match name.as_str() {
276            // https://doc.rust-lang.org/reference/keywords.html#r-lex.keywords
277            "as" => (format_ident!("r#as"), None),
278            "break" => (format_ident!("r#break"), None),
279            "const" => (format_ident!("r#const"), None),
280            "continue" => (format_ident!("r#continue"), None),
281            "crate" => (format_ident!("r#crate"), None),
282            "else" => (format_ident!("r#else"), None),
283            "enum" => (format_ident!("r#enum"), None),
284            "extern" => (format_ident!("r#extern"), None),
285            "false" => (format_ident!("r#false"), None),
286            "fn" => (format_ident!("r#fn"), None),
287            "for" => (format_ident!("r#for"), None),
288            "if" => (format_ident!("r#if"), None),
289            "impl" => (format_ident!("r#impl"), None),
290            "in" => (format_ident!("r#in"), None),
291            "let" => (format_ident!("r#let"), None),
292            "loop" => (format_ident!("r#loop"), None),
293            "match" => (format_ident!("r#match"), None),
294            "mod" => (format_ident!("r#mod"), None),
295            "move" => (format_ident!("r#move"), None),
296            "mut" => (format_ident!("r#mut"), None),
297            "pub" => (format_ident!("r#pub"), None),
298            "ref" => (format_ident!("r#ref"), None),
299            "return" => (format_ident!("r#return"), None),
300            "self" => (format_ident!("r#self"), None),
301            "Self" => (format_ident!("r#Self"), None),
302            "static" => (format_ident!("r#static"), None),
303            "struct" => (format_ident!("r#struct"), None),
304            "super" => (format_ident!("r#super"), None),
305            "trait" => (format_ident!("r#trait"), None),
306            "true" => (format_ident!("r#true"), None),
307            "type" => (format_ident!("r#type"), None),
308            "unsafe" => (format_ident!("r#unsafe"), None),
309            "use" => (format_ident!("r#use"), None),
310            "where" => (format_ident!("r#where"), None),
311            "while" => (format_ident!("r#while"), None),
312            "async" => (format_ident!("r#async"), None),
313            "await" => (format_ident!("r#await"), None),
314            "dyn" => (format_ident!("r#dyn"), None),
315            "abstract" => (format_ident!("r#abstract"), None),
316            "become" => (format_ident!("r#become"), None),
317            "box" => (format_ident!("r#box"), None),
318            "do" => (format_ident!("r#do"), None),
319            "final" => (format_ident!("r#final"), None),
320            "macro" => (format_ident!("r#macro"), None),
321            "override" => (format_ident!("r#override"), None),
322            "priv" => (format_ident!("r#priv"), None),
323            "typeof" => (format_ident!("r#typeof"), None),
324            "unsized" => (format_ident!("r#unsized"), None),
325            "virtual" => (format_ident!("r#virtual"), None),
326            "yield" => (format_ident!("r#yield"), None),
327            "try" => (format_ident!("r#try"), None),
328            "gen" => (format_ident!("r#gen"), None),
329
330            name if name != self.field_name => (
331                format_ident!("{}", self.field_name),
332                Some(quote! { #[serde(rename = #name)]}),
333            ),
334            _ => (format_ident!("{}", self.field_name), None),
335        };
336
337        let ty_inner = self.r#type.codegen(namespace, resolved)?;
338
339        let ty = if !self.required || self.nullable {
340            quote! { Option<#ty_inner> }
341        } else {
342            ty_inner
343        };
344
345        let deprecated = self.deprecated.then(|| {
346            let note = self.description.as_ref().map(|d| quote! { note = #d });
347
348            quote! {
349                #[deprecated(#note)]
350            }
351        });
352
353        Some(quote! {
354            #desc
355            #deprecated
356            #serde_attr
357            pub #name: #ty
358        })
359    }
360}
361
362#[derive(Debug, Clone, PartialEq, Eq, Default)]
363pub struct Object {
364    pub name: String,
365    pub description: Option<String>,
366    pub properties: IndexMap<String, Property>,
367}
368
369impl Object {
370    pub fn from_schema_object(
371        name: &str,
372        schema: &OpenApiType,
373        schemas: &IndexMap<&str, OpenApiType>,
374        warnings: WarningReporter,
375    ) -> Self {
376        let mut result = Object {
377            name: name.to_owned(),
378            description: schema.description.as_deref().map(ToOwned::to_owned),
379            ..Default::default()
380        };
381
382        let Some(props) = &schema.properties else {
383            warnings.push("Missing properties");
384            return result;
385        };
386
387        let required = schema.required.clone().unwrap_or_default();
388
389        for (prop_name, prop) in props {
390            let Some(prop) = Property::from_schema(
391                prop_name,
392                required.contains(prop_name),
393                prop,
394                schemas,
395                warnings.child(prop_name),
396            ) else {
397                continue;
398            };
399
400            let field_name = prop.field_name.clone();
401
402            let entry = result.properties.entry(field_name.clone());
403            if let Entry::Occupied(mut entry) = entry {
404                let other_name = entry.get().name.clone();
405                warnings.push(format!(
406                    "Property name collision: {other_name} and {field_name}"
407                ));
408                // deprioritise kebab and camelcase
409                if other_name.contains('-')
410                    || other_name
411                        .chars()
412                        .filter(|c| c.is_alphabetic())
413                        .all(|c| c.is_ascii_lowercase())
414                {
415                    entry.insert(prop);
416                }
417            } else {
418                entry.insert_entry(prop);
419            }
420        }
421
422        result
423    }
424
425    pub fn from_all_of(
426        name: &str,
427        types: &[OpenApiType],
428        schemas: &IndexMap<&str, OpenApiType>,
429        warnings: WarningReporter,
430    ) -> Self {
431        let mut result = Self {
432            name: name.to_owned(),
433            ..Default::default()
434        };
435
436        for r#type in types {
437            let r#type = if let OpenApiType {
438                ref_path: Some(ref_path),
439                ..
440            } = r#type
441            {
442                let Some(name) = ref_path.strip_prefix("#/components/schemas/") else {
443                    warnings.push(format!("Malformed ref {ref_path}"));
444                    continue;
445                };
446                let Some(schema) = schemas.get(name) else {
447                    warnings.push(format!("Missing schema for ref {name}"));
448                    continue;
449                };
450                schema
451            } else {
452                r#type
453            };
454            let obj = Self::from_schema_object(name, r#type, schemas, warnings.child("variant"));
455
456            result.description = result.description.or(obj.description);
457            result.properties.extend(obj.properties);
458        }
459
460        result
461    }
462
463    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
464        let doc = self.description.as_ref().map(|d| {
465            quote! {
466                #[doc = #d]
467            }
468        });
469
470        let mut namespace = ObjectNamespace {
471            object: self,
472            ident: None,
473            elements: Vec::default(),
474        };
475
476        let mut props = Vec::with_capacity(self.properties.len());
477        for (_, prop) in &self.properties {
478            props.push(prop.codegen(&mut namespace, resolved)?);
479        }
480
481        let name = format_ident!("{}", self.name);
482        let ns = namespace.codegen();
483
484        Some(quote! {
485            #ns
486
487            #doc
488            #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
489            pub struct #name {
490                #(#props),*
491            }
492        })
493    }
494}
495
496pub struct ObjectNamespace<'o> {
497    object: &'o Object,
498    ident: Option<Ident>,
499    elements: Vec<TokenStream>,
500}
501
502impl ObjectNamespace<'_> {
503    pub fn get_ident(&mut self) -> Ident {
504        self.ident
505            .get_or_insert_with(|| {
506                let name = self.object.name.to_snake_case();
507                format_ident!("{name}")
508            })
509            .clone()
510    }
511
512    pub fn push_element(&mut self, el: TokenStream) {
513        self.elements.push(el);
514    }
515
516    pub fn codegen(mut self) -> Option<TokenStream> {
517        if self.elements.is_empty() {
518            None
519        } else {
520            let ident = self.get_ident();
521            let elements = self.elements;
522            Some(quote! {
523                pub mod #ident {
524                    #(#elements)*
525                }
526            })
527        }
528    }
529}
530
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Every `type: object` component schema should resolve without any warnings.
    #[test]
    fn resolve_objects() {
        let schema = get_schema();
        let schemas = &schema.components.schemas;

        let objects: Vec<_> = schemas
            .iter()
            .filter(|(_, desc)| desc.r#type == Some("object"))
            .collect();

        let unresolved: Vec<String> = objects
            .iter()
            .filter(|(name, desc)| {
                let reporter = WarningReporter::new();
                Object::from_schema_object(name, desc, schemas, reporter.clone());
                !reporter.is_empty()
            })
            .map(|(name, _)| format!("`{name}`"))
            .collect();

        assert!(
            unresolved.is_empty(),
            "Failed to resolve {}/{} objects. Could not resolve [{}]",
            unresolved.len(),
            objects.len(),
            unresolved.join(", ")
        );
    }
}