torn_api_codegen/model/enum.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7    parameter::OpenApiParameterSchema,
8    r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, ResolvedSchema};
12
/// Integer representation selected for enums with explicit discriminants.
///
/// Rendered as the matching `#[repr(..)]` attribute on the generated enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumRepr {
    /// Rendered as `#[repr(u8)]`.
    U8,
    /// Rendered as `#[repr(u32)]`.
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21    Ref { ty_name: String },
22    ArrayOfRefs { ty_name: String },
23    Primitive(PrimitiveType),
24    Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29        match schema {
30            OpenApiType {
31                ref_path: Some(path),
32                ..
33            } => Some(Self::Ref {
34                ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35            }),
36            OpenApiType {
37                r#type: Some("array"),
38                items: Some(items),
39                ..
40            } => {
41                let OpenApiType {
42                    ref_path: Some(path),
43                    ..
44                } = items.as_ref()
45                else {
46                    return None;
47                };
48                Some(Self::ArrayOfRefs {
49                    ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50                })
51            }
52            OpenApiType {
53                r#type: Some("string"),
54                format: None,
55                r#enum: None,
56                ..
57            } => Some(Self::Primitive(PrimitiveType::String)),
58            OpenApiType {
59                r#type: Some("string"),
60                format: None,
61                r#enum: Some(_),
62                ..
63            } => {
64                let name = format!("{name}Variant");
65                Some(Self::Enum {
66                    inner: Enum::from_schema(&name, schema)?,
67                    name,
68                })
69            }
70            OpenApiType {
71                r#type: Some("integer"),
72                format: Some("int64"),
73                ..
74            } => Some(Self::Primitive(PrimitiveType::I64)),
75            OpenApiType {
76                r#type: Some("integer"),
77                format: Some("int32"),
78                ..
79            } => Some(Self::Primitive(PrimitiveType::I32)),
80            _ => None,
81        }
82    }
83
84    pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
85        match self {
86            Self::Ref { ty_name } => {
87                let ty = format_ident!("{ty_name}");
88                quote! { crate::models::#ty }
89            }
90            Self::ArrayOfRefs { ty_name } => {
91                let ty = format_ident!("{ty_name}");
92                quote! { Vec<crate::models::#ty> }
93            }
94            Self::Primitive(PrimitiveType::I64) => quote! { i64 },
95            Self::Primitive(PrimitiveType::I32) => quote! { i32 },
96            Self::Primitive(PrimitiveType::Float) => quote! { f32 },
97            Self::Primitive(PrimitiveType::String) => quote! { String },
98            Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
99            Self::Primitive(PrimitiveType::Bool) => quote! { bool },
100            Self::Enum { name, .. } => {
101                let path = ns.get_ident();
102                let ty_name = format_ident!("{name}");
103                quote! {
104                    #path::#ty_name,
105                }
106            }
107        }
108    }
109
110    pub fn name(&self) -> String {
111        match self {
112            Self::Ref { ty_name } => ty_name.clone(),
113            Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
114            Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
115            Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
116            Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
117            Self::Primitive(PrimitiveType::String) => "String".to_owned(),
118            Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
119            Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
120            Self::Enum { .. } => "Variant".to_owned(),
121        }
122    }
123
124    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
125        match self {
126            Self::Primitive(_) => true,
127            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
128                .models
129                .get(ty_name)
130                .map(|f| f.is_display(resolved))
131                .unwrap_or_default(),
132            Self::Enum { inner, .. } => inner.is_display(resolved),
133        }
134    }
135
136    pub fn codegen_display(&self) -> TokenStream {
137        match self {
138            Self::ArrayOfRefs { .. } => quote! {
139                write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
140            },
141            _ => quote! {
142                write!(f, "{}", value)
143            },
144        }
145    }
146}
147
148#[derive(Debug, Clone, PartialEq, Eq)]
149pub enum EnumVariantValue {
150    Repr(u32),
151    String { rename: Option<String> },
152    Tuple(Vec<EnumVariantTupleValue>),
153}
154
155impl Default for EnumVariantValue {
156    fn default() -> Self {
157        Self::String { rename: None }
158    }
159}
160
161impl EnumVariantValue {
162    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
163        match self {
164            Self::Repr(_) | Self::String { .. } => true,
165            Self::Tuple(val) => {
166                val.len() == 1
167                    && val
168                        .iter()
169                        .next()
170                        .map(|v| v.is_display(resolved))
171                        .unwrap_or_default()
172            }
173        }
174    }
175
176    pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
177        let variant = format_ident!("{name}");
178
179        match self {
180            Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
181            Self::String { rename } => {
182                let name = rename.as_deref().unwrap_or(name);
183                Some(quote! { Self::#variant => write!(f, #name) })
184            }
185            Self::Tuple(values) if values.len() == 1 => {
186                let rhs = values.first().unwrap().codegen_display();
187                Some(quote! { Self::#variant(value) => #rhs })
188            }
189            _ => None,
190        }
191    }
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Default)]
195pub struct EnumVariant {
196    pub name: String,
197    pub description: Option<String>,
198    pub value: EnumVariantValue,
199}
200
201pub struct EnumNamespace<'e> {
202    r#enum: &'e Enum,
203    ident: Option<Ident>,
204    elements: Vec<TokenStream>,
205    top_level_elements: Vec<TokenStream>,
206}
207
208impl EnumNamespace<'_> {
209    pub fn get_ident(&mut self) -> Ident {
210        self.ident
211            .get_or_insert_with(|| {
212                let name = self.r#enum.name.to_snake_case();
213                format_ident!("{name}")
214            })
215            .clone()
216    }
217
218    pub fn push_element(&mut self, el: TokenStream) {
219        self.elements.push(el);
220    }
221
222    pub fn push_top_level(&mut self, el: TokenStream) {
223        self.top_level_elements.push(el);
224    }
225
226    pub fn codegen(mut self) -> Option<TokenStream> {
227        if self.elements.is_empty() && self.top_level_elements.is_empty() {
228            None
229        } else {
230            let top_level = &self.top_level_elements;
231            let mut output = quote! {
232                #(#top_level)*
233            };
234
235            if !self.elements.is_empty() {
236                let ident = self.get_ident();
237                let elements = self.elements;
238                output.extend(quote! {
239                    pub mod #ident {
240                        #(#elements)*
241                    }
242                });
243            }
244
245            Some(output)
246        }
247    }
248}
249
250impl EnumVariant {
251    pub fn codegen(
252        &self,
253        ns: &mut EnumNamespace,
254        resolved: &ResolvedSchema,
255    ) -> Option<TokenStream> {
256        let doc = self.description.as_ref().map(|d| {
257            quote! {
258                #[doc = #d]
259            }
260        });
261
262        let name = format_ident!("{}", self.name);
263
264        match &self.value {
265            EnumVariantValue::Repr(repr) => Some(quote! {
266                #doc
267                #name = #repr
268            }),
269            EnumVariantValue::String { rename } => {
270                let serde_attr = rename.as_ref().map(|r| {
271                    quote! {
272                        #[serde(rename = #r)]
273                    }
274                });
275
276                Some(quote! {
277                    #doc
278                    #serde_attr
279                    #name
280                })
281            }
282            EnumVariantValue::Tuple(values) => {
283                let mut val_tys = Vec::with_capacity(values.len());
284
285                if let [value] = values.as_slice() {
286                    let enum_name = format_ident!("{}", ns.r#enum.name);
287                    let ty_name = value.type_name(ns);
288
289                    ns.push_top_level(quote! {
290                        impl From<#ty_name> for #enum_name {
291                            fn from(value: #ty_name) -> Self {
292                                Self::#name(value)
293                            }
294                        }
295                    });
296                }
297
298                for value in values {
299                    let ty_name = value.type_name(ns);
300
301                    if let EnumVariantTupleValue::Enum { inner, .. } = &value {
302                        ns.push_element(inner.codegen(resolved)?);
303                    }
304
305                    val_tys.push(ty_name);
306                }
307
308                Some(quote! {
309                    #name(#(#val_tys),*)
310                })
311            }
312        }
313    }
314
315    pub fn codegen_display(&self) -> Option<TokenStream> {
316        self.value.codegen_display(&self.name)
317    }
318}
319
320#[derive(Debug, Clone, PartialEq, Eq, Default)]
321pub struct Enum {
322    pub name: String,
323    pub description: Option<String>,
324    pub repr: Option<EnumRepr>,
325    pub copy: bool,
326    pub untagged: bool,
327    pub variants: Vec<EnumVariant>,
328}
329
330impl Enum {
331    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
332        let mut result = Enum {
333            name: name.to_owned(),
334            description: schema.description.as_deref().map(ToOwned::to_owned),
335            copy: true,
336            ..Default::default()
337        };
338
339        match &schema.r#enum {
340            Some(OpenApiVariants::Int(int_variants)) => {
341                result.repr = Some(EnumRepr::U32);
342                result.variants = int_variants
343                    .iter()
344                    .copied()
345                    .map(|i| EnumVariant {
346                        name: format!("Variant{i}"),
347                        value: EnumVariantValue::Repr(i as u32),
348                        ..Default::default()
349                    })
350                    .collect();
351            }
352            Some(OpenApiVariants::Str(str_variants)) => {
353                result.variants = str_variants
354                    .iter()
355                    .copied()
356                    .map(|s| {
357                        let transformed = s.replace('&', "And").to_upper_camel_case();
358                        EnumVariant {
359                            value: EnumVariantValue::String {
360                                rename: (transformed != s).then(|| s.to_owned()),
361                            },
362                            name: transformed,
363                            ..Default::default()
364                        }
365                    })
366                    .collect();
367            }
368            None => return None,
369        }
370
371        Some(result)
372    }
373
374    pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
375        let mut result = Self {
376            name: name.to_owned(),
377            copy: true,
378            ..Default::default()
379        };
380
381        for var in schema.r#enum.as_ref()? {
382            let transformed = var.to_upper_camel_case();
383            result.variants.push(EnumVariant {
384                value: EnumVariantValue::String {
385                    rename: (transformed != *var).then(|| transformed.clone()),
386                },
387                name: transformed,
388                ..Default::default()
389            });
390        }
391
392        Some(result)
393    }
394
395    pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
396        let mut result = Self {
397            name: name.to_owned(),
398            untagged: true,
399            ..Default::default()
400        };
401
402        for schema in schemas {
403            let value = EnumVariantTupleValue::from_schema(name, schema)?;
404            let name = value.name();
405
406            result.variants.push(EnumVariant {
407                name,
408                value: EnumVariantValue::Tuple(vec![value]),
409                ..Default::default()
410            });
411        }
412
413        Some(result)
414    }
415
416    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
417        self.variants.iter().all(|v| v.value.is_display(resolved))
418    }
419
420    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
421        let repr = self.repr.map(|r| match r {
422            EnumRepr::U8 => quote! { #[repr(u8)]},
423            EnumRepr::U32 => quote! { #[repr(u32)]},
424        });
425        let name = format_ident!("{}", self.name);
426        let desc = self.description.as_ref().map(|d| {
427            quote! {
428                #repr
429                #[doc = #d]
430            }
431        });
432
433        let mut ns = EnumNamespace {
434            r#enum: self,
435            ident: None,
436            elements: Default::default(),
437            top_level_elements: Default::default(),
438        };
439
440        let is_display = self.is_display(resolved);
441
442        let mut display = Vec::with_capacity(self.variants.len());
443        let mut variants = Vec::with_capacity(self.variants.len());
444        for variant in &self.variants {
445            variants.push(variant.codegen(&mut ns, resolved)?);
446
447            if is_display {
448                display.push(variant.codegen_display()?);
449            }
450        }
451
452        let mut derives = vec![];
453
454        if self.repr.is_some() {
455            derives.push(quote! { serde_repr::Deserialize_repr });
456        } else {
457            derives.push(quote! { serde::Deserialize });
458        }
459
460        if self.copy {
461            derives.push(quote! { Copy, Hash });
462        }
463
464        let serde_attr = self.untagged.then(|| {
465            quote! {
466                #[serde(untagged)]
467            }
468        });
469
470        let display = is_display.then(|| {
471            quote! {
472                impl std::fmt::Display for #name {
473                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
474                        match self {
475                            #(#display),*
476                        }
477                    }
478                }
479            }
480        });
481
482        let module = ns.codegen();
483
484        Some(quote! {
485            #desc
486            #[derive(Debug, Clone, PartialEq, #(#derives),*)]
487            #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
488            #serde_attr
489            pub enum #name {
490                #(#variants),*
491            }
492            #display
493
494            #module
495        })
496    }
497}
498
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Checks that the resolved `TornSelectionName` model reports itself as
    /// displayable against the test schema.
    #[test]
    fn is_display() {
        let open_api = get_schema();
        let resolved = ResolvedSchema::from_open_api(&open_api);

        let model = resolved.models.get("TornSelectionName").unwrap();
        assert!(model.is_display(&resolved));
    }
}